From 3a7abeb0cd4a6e3b99dce36aafe3951add501b7a Mon Sep 17 00:00:00 2001
From: wojciechos
Date: Tue, 10 Dec 2024 21:52:49 +0100
Subject: [PATCH 01/10] Skip error logs for FGW responses with NOT_RECEIVED status (#2303)

* Add NotReceived case handling in adaptTransactionStatus

---------

Co-authored-by: Rian Hughes
---
 rpc/transaction.go                |  9 +++++++--
 starknet/compiler/rust/src/lib.rs | 15 +++++++++++----
 2 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/rpc/transaction.go b/rpc/transaction.go
index 8e3c3956f5..8157b4caa8 100644
--- a/rpc/transaction.go
+++ b/rpc/transaction.go
@@ -565,6 +565,8 @@ func (h *Handler) AddTransaction(ctx context.Context, tx BroadcastedTransaction)
 	}, nil
 }
 
+var errTransactionNotFound = fmt.Errorf("transaction not found")
+
 func (h *Handler) TransactionStatus(ctx context.Context, hash felt.Felt) (*TransactionStatus, *jsonrpc.Error) {
 	receipt, txErr := h.TransactionReceiptByHash(hash)
 	switch txErr {
@@ -585,10 +587,11 @@ func (h *Handler) TransactionStatus(ctx context.Context, hash felt.Felt) (*Trans
 		status, err := adaptTransactionStatus(txStatus)
 		if err != nil {
-			h.log.Errorw("Failed to adapt transaction status", "err", err)
+			if !errors.Is(err, errTransactionNotFound) {
+				h.log.Errorw("Failed to adapt transaction status", "err", err)
+			}
 			return nil, ErrTxnHashNotFound
 		}
-
 		return status, nil
 	}
 	return nil, txErr
@@ -751,6 +754,8 @@ func adaptTransactionStatus(txStatus *starknet.TransactionStatus) (*TransactionS
 		status.Finality = TxnStatusAcceptedOnL2
 	case starknet.Received:
 		status.Finality = TxnStatusReceived
+	case starknet.NotReceived:
+		return nil, errTransactionNotFound
 	default:
 		return nil, fmt.Errorf("unknown finality status: %v", finalityStatus)
 	}

diff --git a/starknet/compiler/rust/src/lib.rs b/starknet/compiler/rust/src/lib.rs
index a1027a3e3b..fa45765cbc 100644
--- a/starknet/compiler/rust/src/lib.rs
+++ b/starknet/compiler/rust/src/lib.rs
@@ -1,6 +1,8 @@
-use cairo_lang_starknet_classes::casm_contract_class::{CasmContractClass, StarknetSierraCompilationError};
+use cairo_lang_starknet_classes::casm_contract_class::{
+    CasmContractClass, StarknetSierraCompilationError,
+};
 use std::ffi::{c_char, CStr, CString};
-use std::panic::{self,AssertUnwindSafe};
+use std::panic::{self, AssertUnwindSafe};
 
 #[no_mangle]
 #[allow(clippy::not_unsafe_ptr_arg_deref)]
@@ -25,9 +27,14 @@ pub extern "C" fn compileSierraToCasm(sierra_json: *const c_char, result: *mut *
         }
     };
 
-    let mut casm_class_result: Option<Result<CasmContractClass, StarknetSierraCompilationError>> = None;
+    let mut casm_class_result: Option<Result<CasmContractClass, StarknetSierraCompilationError>> =
+        None;
     let compilation_result = panic::catch_unwind(AssertUnwindSafe(|| {
-        casm_class_result = Some(CasmContractClass::from_contract_class(sierra_class, true, usize::MAX));
+        casm_class_result = Some(CasmContractClass::from_contract_class(
+            sierra_class,
+            true,
+            usize::MAX,
+        ));
     }));
     if let Err(_) = compilation_result {
        unsafe {

From e75e504eea82d633fa6ff063fbbe036452a11674 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Wed, 11 Dec 2024 07:35:16 +0000
Subject: [PATCH 02/10] Bump nanoid from 3.3.7 to 3.3.8 in /docs in the npm_and_yarn group across 1 directory (#2316)

Bump nanoid in /docs in the npm_and_yarn group across 1 directory

Bumps the npm_and_yarn group with 1 update in the /docs directory: [nanoid](https://github.com/ai/nanoid).
Updates `nanoid` from 3.3.7 to 3.3.8 - [Release notes](https://github.com/ai/nanoid/releases) - [Changelog](https://github.com/ai/nanoid/blob/main/CHANGELOG.md) - [Commits](https://github.com/ai/nanoid/compare/3.3.7...3.3.8) --- updated-dependencies: - dependency-name: nanoid dependency-type: indirect dependency-group: npm_and_yarn ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/package-lock.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/package-lock.json b/docs/package-lock.json index a7efb61e7b..c5da87d581 100644 --- a/docs/package-lock.json +++ b/docs/package-lock.json @@ -12175,9 +12175,9 @@ } }, "node_modules/nanoid": { - "version": "3.3.7", - "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.7.tgz", - "integrity": "sha512-eSRppjcPIatRIMC1U6UngP8XFcz8MQWGQdt1MTBQ7NaAmvXDfvNxbvWV3x2y6CdEUciCSsDHDQZbhYaB8QEo2g==", + "version": "3.3.8", + "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.8.tgz", + "integrity": "sha512-WNLf5Sd8oZxOm+TzppcYk8gVOgP+l58xNy58D0nbUnOxOWRWvlcCV4kUF7ltmI6PsrLl/BgKEyS4mqsGChFN0w==", "funding": [ { "type": "github", From 8862de1088a2e98c1bd018f799e12b6c96200c80 Mon Sep 17 00:00:00 2001 From: wojciechos Date: Wed, 11 Dec 2024 11:20:02 +0100 Subject: [PATCH 03/10] Improve binary build workflow for cross-platform releases (#2315) - Add proper architecture handling in matrix configuration - Implement caching for Go modules and Rust dependencies - Streamline dependency installation for both Linux and macOS - Improve binary artifact handling and checksums - Add retention policy for build artifacts - Split build steps for better clarity and maintainability This update ensures more reliable and efficient binary builds across all supported platforms. 
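For anyone consuming the published artifacts, here is a minimal sketch of verifying a downloaded binary against the `.sha256` file this workflow uploads. The tag, OS, and arch values below are hypothetical placeholders; only the standard `sha256sum`/`shasum` check mode is assumed.

```bash
# Hypothetical artifact name following the juno-<tag>-<os>-<arch> scheme
# introduced by this workflow; substitute the real release values.
ARTIFACT="juno-v0.12.4-linux-amd64"

# Linux: sha256sum -c reads "<hash>  <filename>" lines from the .sha256 file
# and recomputes the digest of the named file.
sha256sum -c "${ARTIFACT}.sha256"

# macOS equivalent, using shasum in check mode.
shasum -a 256 -c "${ARTIFACT}.sha256"
```

Either command exits non-zero on a mismatch, so the same check can gate scripted installs.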
--- .github/workflows/build-binaries.yml | 57 ++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 16 deletions(-) diff --git a/.github/workflows/build-binaries.yml b/.github/workflows/build-binaries.yml index c85c08263c..3e6ae819b4 100644 --- a/.github/workflows/build-binaries.yml +++ b/.github/workflows/build-binaries.yml @@ -18,9 +18,13 @@ jobs: matrix: include: - os: ubuntu-latest + arch: amd64 - os: macos-13 + arch: amd64 - os: ubuntu-arm64-4-core + arch: arm64 - os: macos-latest + arch: arm64 runs-on: ${{ matrix.os }} steps: @@ -29,35 +33,55 @@ jobs: with: fetch-depth: 0 + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + + - name: Set up Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache Rust dependencies + uses: Swatinem/rust-cache@v2 + with: + workspaces: | + vm/rust + core/rust + starknet/compiler/rust + - name: Get latest tag run: echo "TAG=$(git describe --tags)" >> $GITHUB_ENV - name: Get artifact name - run: echo "ARTIFACT_NAME=juno-${{ env.TAG }}-${{ runner.os }}-$(uname -m)" >> $GITHUB_ENV + run: | + OS_NAME=$([ "${{ runner.os }}" == "macOS" ] && echo "darwin" || echo "linux") + echo "ARTIFACT_NAME=juno-${{ env.TAG }}-${OS_NAME}-${{ matrix.arch }}" >> $GITHUB_ENV - name: Install dependencies (Linux) if: runner.os == 'Linux' - run: sudo apt-get update -qq && sudo apt-get install -y upx-ucl build-essential cargo git golang libjemalloc-dev libjemalloc2 -y + run: | + sudo apt-get update -qq + sudo apt-get install -y upx-ucl libjemalloc-dev libjemalloc2 libbz2-dev - name: Install dependencies (macOS) if: runner.os == 'macOS' - run: brew install cargo-c jemalloc - - - name: Set up Go - uses: actions/setup-go@v5 - with: - go-version-file: go.mod + run: brew install jemalloc - - name: Build Juno + - name: Build binary + run: make juno + + - name: Compress binary (Linux) + if: runner.os == 'Linux' run: | - make juno - if [[ "${{ runner.os }}" != "macOS" ]]; then - upx build/juno - fi + upx build/juno mv build/juno ${{ env.ARTIFACT_NAME }} - - name: Generate Checksum - id: checksum + - name: Prepare binary (macOS) + if: runner.os == 'macOS' + run: mv build/juno ${{ env.ARTIFACT_NAME }} + + - name: Generate checksum run: | if [[ "${{ runner.os }}" == "macOS" ]]; then shasum -a 256 ${{ env.ARTIFACT_NAME }} > ${{ env.ARTIFACT_NAME }}.sha256 @@ -65,10 +89,11 @@ jobs: sha256sum ${{ env.ARTIFACT_NAME }} > ${{ env.ARTIFACT_NAME }}.sha256 fi - - name: Upload Artifact + - name: Upload artifact uses: actions/upload-artifact@v4 with: name: ${{ env.ARTIFACT_NAME }} path: | ${{ env.ARTIFACT_NAME }} ${{ env.ARTIFACT_NAME }}.sha256 + retention-days: 30 \ No newline at end of file From 60e8cc9472f6eb79b7dc0021c7413b88ae7f3948 Mon Sep 17 00:00:00 2001 From: AnavarKh <108727035+AnavarKh@users.noreply.github.com> Date: Wed, 11 Dec 2024 16:04:31 +0530 Subject: [PATCH 04/10] Update download link for Juno snapshots from dev to io in Readme file (#2314) --- README.md | 14 +++++++------- docs/docs/snapshots.md | 14 +++++++------- docs/versioned_docs/version-0.11.0/snapshots.md | 6 +++--- docs/versioned_docs/version-0.11.8/snapshots.md | 14 +++++++------- docs/versioned_docs/version-0.12.4/snapshots.md | 14 +++++++------- docs/versioned_docs/version-0.8.0/snapshots.md | 6 +++--- docs/versioned_docs/version-0.9.3/snapshots.md | 6 +++--- 7 files changed, 37 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 977d3c1b79..2773fa248d 100644 --- a/README.md +++ b/README.md @@ -112,32 +112,32 @@ Use the provided snapshots to 
quickly sync your Juno node with the current state | Version | Download Link | | ------- | ------------- | -| **>=v0.9.2** | [**juno_mainnet.tar**](https://juno-snapshots.nethermind.dev/files/mainnet/latest) | +| **>=v0.9.2** | [**juno_mainnet.tar**](https://juno-snapshots.nethermind.io/files/mainnet/latest) | #### Sepolia | Version | Download Link | | ------- | ------------- | -| **>=v0.9.2** | [**juno_sepolia.tar**](https://juno-snapshots.nethermind.dev/files/sepolia/latest) | +| **>=v0.9.2** | [**juno_sepolia.tar**](https://juno-snapshots.nethermind.io/files/sepolia/latest) | #### Sepolia-Integration | Version | Download Link | | ------- | ------------- | -| **>=v0.9.2** | [**juno_sepolia_integration.tar**](https://juno-snapshots.nethermind.dev/files/sepolia-integration/latest) | +| **>=v0.9.2** | [**juno_sepolia_integration.tar**](https://juno-snapshots.nethermind.io/files/sepolia-integration/latest) | ### Getting the size for each snapshot ```console $date Thu 1 Aug 2024 09:49:30 BST -$curl -s -I -L https://juno-snapshots.nethermind.dev/files/mainnet/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' +$curl -s -I -L https://juno-snapshots.nethermind.io/files/mainnet/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' 172.47 GB -$curl -s -I -L https://juno-snapshots.nethermind.dev/files/sepolia/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' +$curl -s -I -L https://juno-snapshots.nethermind.io/files/sepolia/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' 5.67 GB -$curl -s -I -L https://juno-snapshots.nethermind.dev/files/sepolia-integration/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' +$curl -s -I -L https://juno-snapshots.nethermind.io/files/sepolia-integration/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' 2.4 GB ``` @@ -148,7 +148,7 @@ $curl -s -I -L https://juno-snapshots.nethermind.dev/files/sepolia-integration/l Fetch the snapshot from the provided URL: ```bash - wget -O juno_mainnet.tar https://juno-snapshots.nethermind.dev/files/mainnet/latest + wget -O juno_mainnet.tar https://juno-snapshots.nethermind.io/files/mainnet/latest ``` 2. 
**Prepare Directory** diff --git a/docs/docs/snapshots.md b/docs/docs/snapshots.md index b5306e6d41..f272b8a174 100644 --- a/docs/docs/snapshots.md +++ b/docs/docs/snapshots.md @@ -10,19 +10,19 @@ You can download a snapshot of the Juno database to reduce the network syncing t | Version | Download Link | | ------- | ------------- | -| **>=v0.9.2** | [**juno_mainnet.tar**](https://juno-snapshots.nethermind.dev/files/mainnet/latest) | +| **>=v0.9.2** | [**juno_mainnet.tar**](https://juno-snapshots.nethermind.io/files/mainnet/latest) | ## Sepolia | Version | Download Link | | ------- | ------------- | -| **>=v0.9.2** | [**juno_sepolia.tar**](https://juno-snapshots.nethermind.dev/files/sepolia/latest) | +| **>=v0.9.2** | [**juno_sepolia.tar**](https://juno-snapshots.nethermind.io/files/sepolia/latest) | ## Sepolia-Integration | Version | Download Link | | ------- | ------------- | -| **>=v0.9.2** | [**juno_sepolia_integration.tar**](https://juno-snapshots.nethermind.dev/files/sepolia-integration/latest) | +| **>=v0.9.2** | [**juno_sepolia_integration.tar**](https://juno-snapshots.nethermind.io/files/sepolia-integration/latest) | ## Getting snapshot sizes @@ -30,13 +30,13 @@ You can download a snapshot of the Juno database to reduce the network syncing t $date Thu 1 Aug 2024 09:49:30 BST -$curl -s -I -L https://juno-snapshots.nethermind.dev/files/mainnet/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' +$curl -s -I -L https://juno-snapshots.nethermind.io/files/mainnet/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' 172.47 GB -$curl -s -I -L https://juno-snapshots.nethermind.dev/files/sepolia/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' +$curl -s -I -L https://juno-snapshots.nethermind.io/files/sepolia/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' 5.67 GB -$curl -s -I -L https://juno-snapshots.nethermind.dev/files/sepolia-integration/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' +$curl -s -I -L https://juno-snapshots.nethermind.io/files/sepolia-integration/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' 2.4 GB ``` @@ -47,7 +47,7 @@ $curl -s -I -L https://juno-snapshots.nethermind.dev/files/sepolia-integration/l First, download a snapshot from one of the provided URLs: ```bash -wget -O juno_mainnet.tar https://juno-snapshots.nethermind.dev/files/mainnet/latest +wget -O juno_mainnet.tar https://juno-snapshots.nethermind.io/files/mainnet/latest ``` ### 2. 
Prepare a directory diff --git a/docs/versioned_docs/version-0.11.0/snapshots.md b/docs/versioned_docs/version-0.11.0/snapshots.md index 1dffd37dbb..a1b23e98e8 100644 --- a/docs/versioned_docs/version-0.11.0/snapshots.md +++ b/docs/versioned_docs/version-0.11.0/snapshots.md @@ -11,14 +11,14 @@ After downloading a snapshot and starting a Juno node, only recent blocks must b | Version | Size | Block | Download Link | | ------- | ---- | ----- | ------------- | -| **>=v0.9.2** | **182 GB** | **640855** | [**juno_mainnet.tar**](https://juno-snapshots.nethermind.dev/mainnet/juno_mainnet_v0.11.7_640855.tar) | +| **>=v0.9.2** | **182 GB** | **640855** | [**juno_mainnet.tar**](https://juno-snapshots.nethermind.io/mainnet/juno_mainnet_v0.11.7_640855.tar) | ## Sepolia | Version | Size | Block | Download Link | | ------- | ---- | ----- | ------------- | -| **>=v0.9.2** | **5 GB** | **66477** | [**juno_sepolia.tar**](https://juno-snapshots.nethermind.dev/sepolia/juno_sepolia_v0.11.7_66477.tar) | +| **>=v0.9.2** | **5 GB** | **66477** | [**juno_sepolia.tar**](https://juno-snapshots.nethermind.io/sepolia/juno_sepolia_v0.11.7_66477.tar) | ## Run Juno Using Snapshot @@ -27,7 +27,7 @@ After downloading a snapshot and starting a Juno node, only recent blocks must b Fetch a snapshot from one of the provided URLs: ```bash - wget -O juno_mainnet.tar https://juno-snapshots.nethermind.dev/mainnet/juno_mainnet_v0.11.7_640855.tar + wget -O juno_mainnet.tar https://juno-snapshots.nethermind.io/mainnet/juno_mainnet_v0.11.7_640855.tar ``` 2. **Prepare Directory** diff --git a/docs/versioned_docs/version-0.11.8/snapshots.md b/docs/versioned_docs/version-0.11.8/snapshots.md index cff4419224..b5951b4fda 100644 --- a/docs/versioned_docs/version-0.11.8/snapshots.md +++ b/docs/versioned_docs/version-0.11.8/snapshots.md @@ -10,32 +10,32 @@ You can download a snapshot of the Juno database to reduce the network syncing t | Version | Download Link | | ------- | ------------- | -| **>=v0.9.2** | [**juno_mainnet.tar**](https://juno-snapshots.nethermind.dev/files/mainnet/latest) | +| **>=v0.9.2** | [**juno_mainnet.tar**](https://juno-snapshots.nethermind.io/files/mainnet/latest) | ## Sepolia | Version | Download Link | | ------- | ------------- | -| **>=v0.9.2** | [**juno_sepolia.tar**](https://juno-snapshots.nethermind.dev/files/sepolia/latest) | +| **>=v0.9.2** | [**juno_sepolia.tar**](https://juno-snapshots.nethermind.io/files/sepolia/latest) | ## Sepolia-Integration | Version | Download Link | | ------- | ------------- | -| **>=v0.9.2** | [**juno_sepolia_integration.tar**](https://juno-snapshots.nethermind.dev/files/sepolia-integration/latest) | +| **>=v0.9.2** | [**juno_sepolia_integration.tar**](https://juno-snapshots.nethermind.io/files/sepolia-integration/latest) | ### Getting the size for each snapshot ```console $date Thu 1 Aug 2024 09:49:30 BST -$curl -s -I -L https://juno-snapshots.nethermind.dev/files/mainnet/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' +$curl -s -I -L https://juno-snapshots.nethermind.io/files/mainnet/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' 172.47 GB -$curl -s -I -L https://juno-snapshots.nethermind.dev/files/sepolia/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' +$curl -s -I -L https://juno-snapshots.nethermind.io/files/sepolia/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' 5.67 GB -$curl -s -I -L 
https://juno-snapshots.nethermind.dev/files/sepolia-integration/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' +$curl -s -I -L https://juno-snapshots.nethermind.io/files/sepolia-integration/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' 2.4 GB ``` @@ -46,7 +46,7 @@ $curl -s -I -L https://juno-snapshots.nethermind.dev/files/sepolia-integration/l First, download a snapshot from one of the provided URLs: ```bash -wget -O juno_mainnet.tar https://juno-snapshots.nethermind.dev/files/mainnet/latest +wget -O juno_mainnet.tar https://juno-snapshots.nethermind.io/files/mainnet/latest ``` ### 2. Prepare a directory diff --git a/docs/versioned_docs/version-0.12.4/snapshots.md b/docs/versioned_docs/version-0.12.4/snapshots.md index b5306e6d41..f272b8a174 100644 --- a/docs/versioned_docs/version-0.12.4/snapshots.md +++ b/docs/versioned_docs/version-0.12.4/snapshots.md @@ -10,19 +10,19 @@ You can download a snapshot of the Juno database to reduce the network syncing t | Version | Download Link | | ------- | ------------- | -| **>=v0.9.2** | [**juno_mainnet.tar**](https://juno-snapshots.nethermind.dev/files/mainnet/latest) | +| **>=v0.9.2** | [**juno_mainnet.tar**](https://juno-snapshots.nethermind.io/files/mainnet/latest) | ## Sepolia | Version | Download Link | | ------- | ------------- | -| **>=v0.9.2** | [**juno_sepolia.tar**](https://juno-snapshots.nethermind.dev/files/sepolia/latest) | +| **>=v0.9.2** | [**juno_sepolia.tar**](https://juno-snapshots.nethermind.io/files/sepolia/latest) | ## Sepolia-Integration | Version | Download Link | | ------- | ------------- | -| **>=v0.9.2** | [**juno_sepolia_integration.tar**](https://juno-snapshots.nethermind.dev/files/sepolia-integration/latest) | +| **>=v0.9.2** | [**juno_sepolia_integration.tar**](https://juno-snapshots.nethermind.io/files/sepolia-integration/latest) | ## Getting snapshot sizes @@ -30,13 +30,13 @@ You can download a snapshot of the Juno database to reduce the network syncing t $date Thu 1 Aug 2024 09:49:30 BST -$curl -s -I -L https://juno-snapshots.nethermind.dev/files/mainnet/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' +$curl -s -I -L https://juno-snapshots.nethermind.io/files/mainnet/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' 172.47 GB -$curl -s -I -L https://juno-snapshots.nethermind.dev/files/sepolia/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' +$curl -s -I -L https://juno-snapshots.nethermind.io/files/sepolia/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' 5.67 GB -$curl -s -I -L https://juno-snapshots.nethermind.dev/files/sepolia-integration/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' +$curl -s -I -L https://juno-snapshots.nethermind.io/files/sepolia-integration/latest | gawk -v IGNORECASE=1 '/^Content-Length/ { printf "%.2f GB\n", $2/1024/1024/1024 }' 2.4 GB ``` @@ -47,7 +47,7 @@ $curl -s -I -L https://juno-snapshots.nethermind.dev/files/sepolia-integration/l First, download a snapshot from one of the provided URLs: ```bash -wget -O juno_mainnet.tar https://juno-snapshots.nethermind.dev/files/mainnet/latest +wget -O juno_mainnet.tar https://juno-snapshots.nethermind.io/files/mainnet/latest ``` ### 2. 
Prepare a directory diff --git a/docs/versioned_docs/version-0.8.0/snapshots.md b/docs/versioned_docs/version-0.8.0/snapshots.md index c52ac639f1..73964bf56e 100644 --- a/docs/versioned_docs/version-0.8.0/snapshots.md +++ b/docs/versioned_docs/version-0.8.0/snapshots.md @@ -11,13 +11,13 @@ After downloading a snapshot and starting a Juno node, only recent blocks must b | Version | Size | Block | Download Link | | ------- | ---- | ----- | ------------- | -| **>=v0.6.0** | **121 GB** | **449406** | [**juno_mainnet.tar**](https://juno-snapshots.nethermind.dev/mainnet/juno_mainnet_v0.7.5_449406.tar) | +| **>=v0.6.0** | **121 GB** | **449406** | [**juno_mainnet.tar**](https://juno-snapshots.nethermind.io/mainnet/juno_mainnet_v0.7.5_449406.tar) | ## Goerli | Version | Size | Block | Download Link | | ------- | ---- | ----- | ------------- | -| **>=v0.6.0** | **41.4 GB** | **911580** | [**juno_goerli.tar**](https://juno-snapshots.nethermind.dev/goerli/juno_goerli_v0.7.5_911580.tar) | +| **>=v0.6.0** | **41.4 GB** | **911580** | [**juno_goerli.tar**](https://juno-snapshots.nethermind.io/goerli/juno_goerli_v0.7.5_911580.tar) | ## Run Juno Using Snapshot @@ -26,7 +26,7 @@ After downloading a snapshot and starting a Juno node, only recent blocks must b Fetch a snapshot from one of the provided URLs: ```bash - wget -O juno_mainnet.tar https://juno-snapshots.nethermind.dev/mainnet/juno_mainnet_v0.7.5_449406.tar + wget -O juno_mainnet.tar https://juno-snapshots.nethermind.io/mainnet/juno_mainnet_v0.7.5_449406.tar ``` 2. **Prepare Directory** diff --git a/docs/versioned_docs/version-0.9.3/snapshots.md b/docs/versioned_docs/version-0.9.3/snapshots.md index b0258f0249..894d72c6bf 100644 --- a/docs/versioned_docs/version-0.9.3/snapshots.md +++ b/docs/versioned_docs/version-0.9.3/snapshots.md @@ -11,13 +11,13 @@ After downloading a snapshot and starting a Juno node, only recent blocks must b | Version | Size | Block | Download Link | | ------- | ---- | ----- | ------------- | -| **>=v0.9.2** | **156 GB** | **519634** | [**juno_mainnet.tar**](https://juno-snapshots.nethermind.dev/mainnet/juno_mainnet_v0.9.3_519634.tar) | +| **>=v0.9.2** | **156 GB** | **519634** | [**juno_mainnet.tar**](https://juno-snapshots.nethermind.io/mainnet/juno_mainnet_v0.9.3_519634.tar) | ## Goerli | Version | Size | Block | Download Link | | ------- | ---- | ----- | ------------- | -| **>=v0.6.0** | **41.4 GB** | **911580** | [**juno_goerli.tar**](https://juno-snapshots.nethermind.dev/goerli/juno_goerli_v0.7.5_911580.tar) | +| **>=v0.6.0** | **41.4 GB** | **911580** | [**juno_goerli.tar**](https://juno-snapshots.nethermind.io/goerli/juno_goerli_v0.7.5_911580.tar) | ## Run Juno Using Snapshot @@ -26,7 +26,7 @@ After downloading a snapshot and starting a Juno node, only recent blocks must b Fetch a snapshot from one of the provided URLs: ```bash - wget -O juno_mainnet.tar https://juno-snapshots.nethermind.dev/mainnet/juno_mainnet_v0.9.3_519634.tar + wget -O juno_mainnet.tar https://juno-snapshots.nethermind.io/mainnet/juno_mainnet_v0.9.3_519634.tar ``` 2. 
**Prepare Directory** From 91d0f8e87c454d989273022ffc43d6a4040b71e2 Mon Sep 17 00:00:00 2001 From: Kirill Date: Wed, 11 Dec 2024 16:01:10 +0400 Subject: [PATCH 05/10] Add schema_version to output of db info command (#2309) --- cmd/juno/dbcmd.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/cmd/juno/dbcmd.go b/cmd/juno/dbcmd.go index 4fe5cd3a81..8d80ae1ccc 100644 --- a/cmd/juno/dbcmd.go +++ b/cmd/juno/dbcmd.go @@ -10,6 +10,7 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/pebble" + "github.com/NethermindEth/juno/migration" "github.com/NethermindEth/juno/utils" "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" @@ -21,6 +22,7 @@ const ( type DBInfo struct { Network string `json:"network"` + SchemaVersion uint64 `json:"schema_version"` ChainHeight uint64 `json:"chain_height"` LatestBlockHash *felt.Felt `json:"latest_block_hash"` LatestStateRoot *felt.Felt `json:"latest_state_root"` @@ -84,7 +86,7 @@ func dbInfo(cmd *cobra.Command, args []string) error { defer database.Close() chain := blockchain.New(database, nil) - info := DBInfo{} + var info DBInfo // Get the latest block information headBlock, err := chain.Head() @@ -97,6 +99,12 @@ func dbInfo(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to get the state update: %v", err) } + schemaMeta, err := migration.SchemaMetadata(database) + if err != nil { + return fmt.Errorf("failed to get schema metadata: %v", err) + } + + info.SchemaVersion = schemaMeta.Version info.Network = getNetwork(headBlock, stateUpdate.StateDiff) info.ChainHeight = headBlock.Number info.LatestBlockHash = headBlock.Hash From 8bf9be9fe9ac4d1dc279dd77bdd4c2e7c5028a4a Mon Sep 17 00:00:00 2001 From: Rian Hughes Date: Wed, 11 Dec 2024 14:11:22 +0200 Subject: [PATCH 06/10] update invoke v3 txn validation to require sender_address (#2308) --- rpc/transaction.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rpc/transaction.go b/rpc/transaction.go index 8157b4caa8..c357f345c6 100644 --- a/rpc/transaction.go +++ b/rpc/transaction.go @@ -210,7 +210,7 @@ type Transaction struct { ContractAddressSalt *felt.Felt `json:"contract_address_salt,omitempty" validate:"required_if=Type DEPLOY,required_if=Type DEPLOY_ACCOUNT"` ClassHash *felt.Felt `json:"class_hash,omitempty" validate:"required_if=Type DEPLOY,required_if=Type DEPLOY_ACCOUNT"` ConstructorCallData *[]*felt.Felt `json:"constructor_calldata,omitempty" validate:"required_if=Type DEPLOY,required_if=Type DEPLOY_ACCOUNT"` - SenderAddress *felt.Felt `json:"sender_address,omitempty" validate:"required_if=Type DECLARE,required_if=Type INVOKE Version 0x1"` + SenderAddress *felt.Felt `json:"sender_address,omitempty" validate:"required_if=Type DECLARE,required_if=Type INVOKE Version 0x1,required_if=Type INVOKE Version 0x3"` Signature *[]*felt.Felt `json:"signature,omitempty" validate:"required"` CallData *[]*felt.Felt `json:"calldata,omitempty" validate:"required_if=Type INVOKE"` EntryPointSelector *felt.Felt `json:"entry_point_selector,omitempty" validate:"required_if=Type INVOKE Version 0x0"` From 0a21162f7f5a06951f95f5d4c7a748361cd3b29c Mon Sep 17 00:00:00 2001 From: Daniil Ankushin Date: Thu, 12 Dec 2024 00:04:08 +0700 Subject: [PATCH 07/10] Remove unused code (#2317) --- p2p/p2p.go | 101 +--- p2p/p2p_test.go | 142 ------ p2p/starknet/handlers.go | 10 +- p2p/starknet/starknet_test.go | 909 ---------------------------------- p2p/sync.go | 6 - 5 files changed, 7 insertions(+), 1161 
deletions(-) delete mode 100644 p2p/starknet/starknet_test.go diff --git a/p2p/p2p.go b/p2p/p2p.go index 49633f49ee..f0b54c3381 100644 --- a/p2p/p2p.go +++ b/p2p/p2p.go @@ -7,7 +7,6 @@ import ( "fmt" "math/rand" "strings" - "sync" "time" "github.com/Masterminds/semver/v3" @@ -21,7 +20,6 @@ import ( pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/crypto/pb" - "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" @@ -43,10 +41,8 @@ type Service struct { handler *starknet.Handler log utils.SimpleLogger - dht *dht.IpfsDHT - pubsub *pubsub.PubSub - topics map[string]*pubsub.Topic - topicsLock sync.RWMutex + dht *dht.IpfsDHT + pubsub *pubsub.PubSub synchroniser *syncService gossipTracer *gossipTracer @@ -157,7 +153,6 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai network: snNetwork, dht: p2pdht, feederNode: feederNode, - topics: make(map[string]*pubsub.Topic), handler: starknet.NewHandler(bc, log), database: database, } @@ -204,34 +199,6 @@ func privateKey(privKeyStr string) (crypto.PrivKey, error) { return prvKey, nil } -func (s *Service) SubscribePeerConnectednessChanged(ctx context.Context) (<-chan event.EvtPeerConnectednessChanged, error) { - ch := make(chan event.EvtPeerConnectednessChanged) - sub, err := s.host.EventBus().Subscribe(&event.EvtPeerConnectednessChanged{}) - if err != nil { - return nil, err - } - - go func() { - for { - select { - case <-ctx.Done(): - if err = sub.Close(); err != nil { - s.log.Warnw("Failed to close subscription", "err", err) - } - close(ch) - return - case evnt := <-sub.Out(): - typedEvnt := evnt.(event.EvtPeerConnectednessChanged) - if typedEvnt.Connectedness == network.Connected { - ch <- typedEvnt - } - } - } - }() - - return ch, nil -} - // Run starts the p2p service. 
Calling any other function before run is undefined behaviour func (s *Service) Run(ctx context.Context) error { defer func() { @@ -336,70 +303,6 @@ func (s *Service) NewStream(ctx context.Context, pids ...protocol.ID) (network.S } } -func (s *Service) joinTopic(topic string) (*pubsub.Topic, error) { - existingTopic := func() *pubsub.Topic { - s.topicsLock.RLock() - defer s.topicsLock.RUnlock() - if t, found := s.topics[topic]; found { - return t - } - return nil - }() - - if existingTopic != nil { - return existingTopic, nil - } - - newTopic, err := s.pubsub.Join(topic) - if err != nil { - return nil, err - } - - s.topicsLock.Lock() - defer s.topicsLock.Unlock() - s.topics[topic] = newTopic - return newTopic, nil -} - -func (s *Service) SubscribeToTopic(topic string) (chan []byte, func(), error) { - t, joinErr := s.joinTopic(topic) - if joinErr != nil { - return nil, nil, joinErr - } - - sub, subErr := t.Subscribe() - if subErr != nil { - return nil, nil, subErr - } - - const bufferSize = 16 - ch := make(chan []byte, bufferSize) - // go func() { - // for { - // msg, err := sub.Next(s.runCtx) - // if err != nil { - // close(ch) - // return - // } - // only forward messages delivered by others - // if msg.ReceivedFrom == s.host.ID() { - // continue - // } - // - // select { - // case ch <- msg.GetData(): - // case <-s.runCtx.Done(): - // } - // } - // }() - return ch, sub.Cancel, nil -} - -func (s *Service) PublishOnTopic(topic string) error { - _, err := s.joinTopic(topic) - return err -} - func (s *Service) SetProtocolHandler(pid protocol.ID, handler func(network.Stream)) { s.host.SetStreamHandler(pid, handler) } diff --git a/p2p/p2p_test.go b/p2p/p2p_test.go index 070a9eedb8..54b19d5900 100644 --- a/p2p/p2p_test.go +++ b/p2p/p2p_test.go @@ -1,141 +1,17 @@ package p2p_test import ( - "context" - "io" - "strings" - "sync" "testing" - "time" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/pebble" "github.com/NethermindEth/juno/p2p" "github.com/NethermindEth/juno/utils" - "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" - "github.com/libp2p/go-libp2p/core/protocol" - mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" ) -func TestService(t *testing.T) { - t.Skip("TestService") - net, err := mocknet.FullMeshLinked(2) - require.NoError(t, err) - peerHosts := net.Hosts() - require.Len(t, peerHosts, 2) - - timeout := time.Second - testCtx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - peerA, err := p2p.NewWithHost( - peerHosts[0], - "", - false, - nil, - &utils.Integration, - utils.NewNopZapLogger(), - nil, - ) - require.NoError(t, err) - - events, err := peerA.SubscribePeerConnectednessChanged(testCtx) - require.NoError(t, err) - - peerAddrs, err := peerA.ListenAddrs() - require.NoError(t, err) - - peerAddrsString := make([]string, 0, len(peerAddrs)) - for _, addr := range peerAddrs { - peerAddrsString = append(peerAddrsString, addr.String()) - } - - peerB, err := p2p.NewWithHost( - peerHosts[1], - strings.Join(peerAddrsString, ","), - true, - nil, - &utils.Integration, - utils.NewNopZapLogger(), - nil, - ) - require.NoError(t, err) - - wg := sync.WaitGroup{} - wg.Add(2) - go func() { - defer wg.Done() - require.NoError(t, peerA.Run(testCtx)) - }() - go func() { - defer wg.Done() - require.NoError(t, peerB.Run(testCtx)) - }() - - select { - case evt := <-events: - require.Equal(t, network.Connected, evt.Connectedness) - case 
<-time.After(timeout): - require.True(t, false, "no events were emitted") - } - - t.Run("gossip", func(t *testing.T) { - t.Skip() // todo: flaky test - topic := "coolTopic" - ch, closer, err := peerA.SubscribeToTopic(topic) - require.NoError(t, err) - t.Cleanup(closer) - - maxRetries := 4 - RetryLoop: - for i := 0; i < maxRetries; i++ { - gossipedMessage := []byte(`veryImportantMessage`) - require.NoError(t, peerB.PublishOnTopic(topic)) - - select { - case <-time.After(time.Second): - if i == maxRetries-1 { - require.Fail(t, "timeout: never received the message") - } - case msg := <-ch: - require.Equal(t, gossipedMessage, msg) - break RetryLoop - } - } - }) - - t.Run("protocol handler", func(t *testing.T) { - ch := make(chan []byte) - - superSecretProtocol := protocol.ID("superSecretProtocol") - peerA.SetProtocolHandler(superSecretProtocol, func(stream network.Stream) { - read, err := io.ReadAll(stream) - require.NoError(t, err) - ch <- read - }) - - peerAStream, err := peerB.NewStream(testCtx, superSecretProtocol) - require.NoError(t, err) - - superSecretMessage := []byte(`superSecretMessage`) - _, err = peerAStream.Write(superSecretMessage) - require.NoError(t, err) - require.NoError(t, peerAStream.Close()) - - select { - case <-time.After(timeout): - require.Equal(t, true, false) - case msg := <-ch: - require.Equal(t, superSecretMessage, msg) - } - }) - - cancel() - wg.Wait() -} - func TestInvalidKey(t *testing.T) { _, err := p2p.New( "/ip4/127.0.0.1/tcp/30301", @@ -153,24 +29,6 @@ func TestInvalidKey(t *testing.T) { require.Error(t, err) } -func TestValidKey(t *testing.T) { - t.Skip("TestValidKey") - _, err := p2p.New( - "/ip4/127.0.0.1/tcp/30301", - "", - "peerA", - "", - "08011240333b4a433f16d7ca225c0e99d0d8c437b835cb74a98d9279c561977690c80f681b25ccf3fa45e2f2de260149c112fa516b69057dd3b0151a879416c0cb12d9b3", - false, - nil, - &utils.Integration, - utils.NewNopZapLogger(), - nil, - ) - - require.NoError(t, err) -} - func TestLoadAndPersistPeers(t *testing.T) { testDB := pebble.NewMemTest(t) diff --git a/p2p/starknet/handlers.go b/p2p/starknet/handlers.go index 6a75c29adb..33fb1fbacd 100644 --- a/p2p/starknet/handlers.go +++ b/p2p/starknet/handlers.go @@ -106,23 +106,23 @@ func streamHandler[ReqT proto.Message](ctx context.Context, wg *sync.WaitGroup, } func (h *Handler) HeadersHandler(stream network.Stream) { - streamHandler[*spec.BlockHeadersRequest](h.ctx, &h.wg, stream, h.onHeadersRequest, h.log) + streamHandler(h.ctx, &h.wg, stream, h.onHeadersRequest, h.log) } func (h *Handler) EventsHandler(stream network.Stream) { - streamHandler[*spec.EventsRequest](h.ctx, &h.wg, stream, h.onEventsRequest, h.log) + streamHandler(h.ctx, &h.wg, stream, h.onEventsRequest, h.log) } func (h *Handler) TransactionsHandler(stream network.Stream) { - streamHandler[*spec.TransactionsRequest](h.ctx, &h.wg, stream, h.onTransactionsRequest, h.log) + streamHandler(h.ctx, &h.wg, stream, h.onTransactionsRequest, h.log) } func (h *Handler) ClassesHandler(stream network.Stream) { - streamHandler[*spec.ClassesRequest](h.ctx, &h.wg, stream, h.onClassesRequest, h.log) + streamHandler(h.ctx, &h.wg, stream, h.onClassesRequest, h.log) } func (h *Handler) StateDiffHandler(stream network.Stream) { - streamHandler[*spec.StateDiffsRequest](h.ctx, &h.wg, stream, h.onStateDiffRequest, h.log) + streamHandler(h.ctx, &h.wg, stream, h.onStateDiffRequest, h.log) } func (h *Handler) onHeadersRequest(req *spec.BlockHeadersRequest) (iter.Seq[proto.Message], error) { diff --git a/p2p/starknet/starknet_test.go 
b/p2p/starknet/starknet_test.go deleted file mode 100644 index 05d3b6dc67..0000000000 --- a/p2p/starknet/starknet_test.go +++ /dev/null @@ -1,909 +0,0 @@ -package starknet_test - -// func nopCloser() error { return nil } -// -// func TestClientHandler(t *testing.T) { //nolint:gocyclo -// mockCtrl := gomock.NewController(t) -// t.Cleanup(mockCtrl.Finish) -// -// testNetwork := utils.Integration -// testCtx, cancel := context.WithCancel(context.Background()) -// t.Cleanup(cancel) -// -// mockNet, err := mocknet.FullMeshConnected(2) -// require.NoError(t, err) -// -// peers := mockNet.Peers() -// require.Len(t, peers, 2) -// handlerID := peers[0] -// clientID := peers[1] -// -// log, err := utils.NewZapLogger(utils.ERROR, false) -// require.NoError(t, err) -// mockReader := mocks.NewMockReader(mockCtrl) -// handler := starknet.NewHandler(mockReader, log) -// -// handlerHost := mockNet.Host(handlerID) -// handlerHost.SetStreamHandler(starknet.CurrentBlockHeaderPID(testNetwork), handler.CurrentBlockHeaderHandler) -// handlerHost.SetStreamHandler(starknet.HeadersPID(&testNetwork), handler.HeadersHandler) -// handlerHost.SetStreamHandler(starknet.BlockBodiesPID(&testNetwork), handler.BlockBodiesHandler) -// handlerHost.SetStreamHandler(starknet.EventsPID(&testNetwork), handler.EventsHandler) -// handlerHost.SetStreamHandler(starknet.ReceiptsPID(&testNetwork), handler.ReceiptsHandler) -// handlerHost.SetStreamHandler(starknet.TransactionsPID(&testNetwork), handler.TransactionsHandler) -// -// clientHost := mockNet.Host(clientID) -// client := starknet.NewClient(func(ctx context.Context, pids ...protocol.ID) (network.Stream, error) { -// return clientHost.NewStream(ctx, handlerID, pids...) -// }, &testNetwork, log) -// -// t.Run("get block headers", func(t *testing.T) { -// type pair struct { -// header *core.Header -// commitments *core.BlockCommitments -// } -// pairsPerBlock := []pair{} -// for i := uint64(0); i < 2; i++ { -// pairsPerBlock = append(pairsPerBlock, pair{ -// header: fillFelts(t, &core.Header{ -// Number: i, -// Timestamp: i, -// TransactionCount: i, -// EventCount: i, -// }), -// commitments: fillFelts(t, &core.BlockCommitments{}), -// }) -// } -// -// for blockNumber, pair := range pairsPerBlock { -// blockNumber := uint64(blockNumber) -// mockReader.EXPECT().BlockHeaderByNumber(blockNumber).Return(pair.header, nil) -// mockReader.EXPECT().BlockCommitmentsByNumber(blockNumber).Return(pair.commitments, nil) -// } -// -// numOfBlocks := uint64(len(pairsPerBlock)) -// res, cErr := client.RequestBlockHeaders(testCtx, &spec.BlockHeadersRequest{ -// Iteration: &spec.Iteration{ -// Start: &spec.Iteration_BlockNumber{ -// BlockNumber: 0, -// }, -// Direction: spec.Iteration_Forward, -// Limit: numOfBlocks, -// Step: 1, -// }, -// }) -// require.NoError(t, cErr) -// -// var count uint64 -// for response, valid := res(); valid; response, valid = res() { -// if count == numOfBlocks { -// assert.True(t, proto.Equal(&spec.Fin{}, response.Part[0].GetFin())) -// count++ -// break -// } -// -// expectedPair := pairsPerBlock[count] -// expectedResponse := expectedHeaderResponse(expectedPair.header, expectedPair.commitments) -// assert.True(t, proto.Equal(expectedResponse, response)) -// -// assert.Equal(t, count, response.Part[0].GetHeader().Number) -// count++ -// } -// -// expectedCount := numOfBlocks + 1 // plus fin -// require.Equal(t, expectedCount, count) -// -// t.Run("get current block header", func(t *testing.T) { -// headerAndCommitments := pairsPerBlock[0] -// 
mockReader.EXPECT().Height().Return(headerAndCommitments.header.Number, nil) -// mockReader.EXPECT().BlockHeaderByNumber(headerAndCommitments.header.Number).Return(headerAndCommitments.header, nil) -// mockReader.EXPECT().BlockCommitmentsByNumber(headerAndCommitments.header.Number).Return(headerAndCommitments.commitments, nil) -// -// res, cErr := client.RequestCurrentBlockHeader(testCtx, &spec.CurrentBlockHeaderRequest{}) -// require.NoError(t, cErr) -// -// count, numOfBlocks = 0, 1 -// for response, valid := res(); valid; response, valid = res() { -// if count == numOfBlocks { -// assert.True(t, proto.Equal(&spec.Fin{}, response.Part[0].GetFin())) -// count++ -// break -// } -// -// expectedPair := headerAndCommitments -// expectedResponse := expectedHeaderResponse(expectedPair.header, expectedPair.commitments) -// assert.True(t, proto.Equal(expectedResponse, response)) -// -// assert.Equal(t, count, response.Part[0].GetHeader().Number) -// count++ -// } -// expectedCount := numOfBlocks + 1 // plus fin -// require.Equal(t, expectedCount, count) -// }) -// }) -// -// t.Run("get block bodies", func(t *testing.T) { -// /* -// deployedClassHash := utils.HexToFelt(t, "0XCAFEBABE") -// deployedAddress := utils.HexToFelt(t, "0XDEADBEEF") -// replacedClassHash := utils.HexToFelt(t, "0XABCD") -// replacedAddress := utils.HexToFelt(t, "0XABCDE") -// declaredV0ClassAddr := randFelt(t) -// declaredV0ClassHash := randFelt(t) -// storageDiff := core.StorageDiff{ -// Key: randFelt(t), -// Value: randFelt(t), -// } -// const ( -// cairo0Program = "cairo_0_program" -// cairo1Program = "cairo_1_program" -// ) -// cairo1Class := &core.Cairo1Class{ -// Abi: "cairo1 class abi", -// AbiHash: randFelt(t), -// EntryPoints: struct { -// Constructor []core.SierraEntryPoint -// External []core.SierraEntryPoint -// L1Handler []core.SierraEntryPoint -// }{}, -// Program: feltSlice(2), -// ProgramHash: randFelt(t), -// SemanticVersion: "1", -// Compiled: json.RawMessage(cairo1Program), -// } -// -// cairo0Class := &core.Cairo0Class{ -// Abi: json.RawMessage("cairo0 class abi"), -// Program: cairo1Program, -// } -// -// blocks := []struct { -// number uint64 -// stateDiff *core.StateDiff -// }{ -// { -// number: 0, -// stateDiff: &core.StateDiff{ -// StorageDiffs: map[felt.Felt][]core.StorageDiff{ -// *deployedAddress: { -// storageDiff, -// }, -// }, -// Nonces: map[felt.Felt]*felt.Felt{ -// *deployedAddress: randFelt(t), -// *replacedAddress: randFelt(t), -// }, -// DeployedContracts: []core.AddressClassHashPair{ -// { -// Address: deployedAddress, -// ClassHash: deployedClassHash, -// }, -// }, -// DeclaredV0Classes: []*felt.Felt{declaredV0ClassAddr}, -// DeclaredV1Classes: []core.DeclaredV1Class{ -// { -// ClassHash: randFelt(t), -// CompiledClassHash: randFelt(t), -// }, -// }, -// ReplacedClasses: []core.AddressClassHashPair{ -// { -// Address: replacedAddress, -// ClassHash: replacedClassHash, -// }, -// }, -// }, -// }, -// { -// number: 1, -// stateDiff: &core.StateDiff{ // State Diff with a class declared and deployed in the same block -// StorageDiffs: map[felt.Felt][]core.StorageDiff{ -// *deployedAddress: { -// storageDiff, -// }, -// }, -// Nonces: map[felt.Felt]*felt.Felt{ -// *deployedAddress: randFelt(t), -// *replacedAddress: randFelt(t), -// }, -// DeployedContracts: []core.AddressClassHashPair{ -// { -// Address: deployedAddress, -// ClassHash: deployedClassHash, -// }, -// { -// Address: declaredV0ClassAddr, -// ClassHash: declaredV0ClassHash, -// }, -// }, -// DeclaredV0Classes: 
[]*felt.Felt{declaredV0ClassHash}, -// DeclaredV1Classes: []core.DeclaredV1Class{ -// { -// ClassHash: randFelt(t), -// CompiledClassHash: randFelt(t), -// }, -// }, -// ReplacedClasses: []core.AddressClassHashPair{ -// { -// Address: replacedAddress, -// ClassHash: replacedClassHash, -// }, -// }, -// }, -// }, -// } -// limit := uint64(len(blocks)) -// -// for _, block := range blocks { -// mockReader.EXPECT().BlockHeaderByNumber(block.number).Return(&core.Header{ -// Number: block.number, -// }, nil) -// -// mockReader.EXPECT().StateUpdateByNumber(block.number).Return(&core.StateUpdate{ -// StateDiff: block.stateDiff, -// }, nil) -// -// stateHistory := mocks.NewMockStateHistoryReader(mockCtrl) -// v0Class := block.stateDiff.DeclaredV0Classes[0] -// stateHistory.EXPECT().Class(v0Class).Return(&core.DeclaredClass{ -// At: block.number, -// Class: cairo0Class, -// }, nil) -// v1Class := block.stateDiff.DeclaredV1Classes[0] -// stateHistory.EXPECT().Class(v1Class.ClassHash).Return(&core.DeclaredClass{ -// At: block.number, -// Class: cairo1Class, -// }, nil) -// -// stateHistory.EXPECT().ContractClassHash(deployedAddress).Return(deployedClassHash, nil).AnyTimes() -// stateHistory.EXPECT().ContractClassHash(replacedAddress).Return(replacedClassHash, nil).AnyTimes() -// -// mockReader.EXPECT().StateAtBlockNumber(block.number).Return(stateHistory, nopCloser, nil) -// } -// -// res, cErr := client.RequestBlockBodies(testCtx, &spec.BlockBodiesRequest{ -// Iteration: &spec.Iteration{ -// Start: &spec.Iteration_BlockNumber{ -// BlockNumber: blocks[0].number, -// }, -// Direction: spec.Iteration_Forward, -// Limit: limit, -// Step: 1, -// }, -// }) -// require.NoError(t, cErr) -// -// var expectedMessages []*spec.BlockBodiesResponse -// -// for _, b := range blocks { -// expectedMessages = append(expectedMessages, []*spec.BlockBodiesResponse{ -// { -// Id: &spec.BlockID{ -// Number: b.number, -// }, -// BodyMessage: &spec.BlockBodiesResponse_Diff{ -// Diff: &spec.StateDiff{ -// ContractDiffs: []*spec.StateDiff_ContractDiff{ -// { -// Address: core2p2p.AdaptAddress(deployedAddress), -// ClassHash: core2p2p.AdaptFelt(deployedClassHash), -// Nonce: core2p2p.AdaptFelt(b.stateDiff.Nonces[*deployedAddress]), -// Values: []*spec.ContractStoredValue{ -// { -// Key: core2p2p.AdaptFelt(storageDiff.Key), -// Value: core2p2p.AdaptFelt(storageDiff.Value), -// }, -// }, -// }, -// { -// Address: core2p2p.AdaptAddress(replacedAddress), -// ClassHash: core2p2p.AdaptFelt(replacedClassHash), -// Nonce: core2p2p.AdaptFelt(b.stateDiff.Nonces[*replacedAddress]), -// }, -// }, -// ReplacedClasses: utils.Map(b.stateDiff.ReplacedClasses, core2p2p.AdaptAddressClassHashPair), -// DeployedContracts: utils.Map(b.stateDiff.DeployedContracts, core2p2p.AdaptAddressClassHashPair), -// }, -// }, -// }, -// { -// Id: &spec.BlockID{ -// Number: b.number, -// }, -// BodyMessage: &spec.BlockBodiesResponse_Classes{ -// Classes: &spec.Classes{ -// Domain: 0, -// Classes: []*spec.Class{core2p2p.AdaptClass(cairo0Class), core2p2p.AdaptClass(cairo1Class)}, -// }, -// }, -// }, -// { -// Id: &spec.BlockID{ -// Number: b.number, -// }, -// BodyMessage: &spec.BlockBodiesResponse_Proof{ -// Proof: &spec.BlockProof{ -// Proof: nil, -// }, -// }, -// }, -// { -// Id: &spec.BlockID{ -// Number: b.number, -// }, -// BodyMessage: &spec.BlockBodiesResponse_Fin{}, -// }, -// }...) 
-// } -// -// expectedMessages = append(expectedMessages, &spec.BlockBodiesResponse{ -// Id: nil, -// BodyMessage: &spec.BlockBodiesResponse_Fin{}, -// }) -// -// var count int -// for body, valid := res(); valid; body, valid = res() { -// if bodyProof, ok := body.BodyMessage.(*spec.BlockBodiesResponse_Proof); ok { -// // client generates random slice of bytes in proofs for now -// bodyProof.Proof = nil -// } -// -// if count == 0 || count == 4 { -// diff := body.BodyMessage.(*spec.BlockBodiesResponse_Diff).Diff.ContractDiffs -// sortContractDiff(diff) -// -// expectedDiff := expectedMessages[count].BodyMessage.(*spec.BlockBodiesResponse_Diff).Diff.ContractDiffs -// sortContractDiff(expectedDiff) -// } -// -// if !assert.True(t, proto.Equal(expectedMessages[count], body), "iteration %d, type %T", count, body.BodyMessage) { -// spew.Dump(body.BodyMessage) -// spew.Dump(expectedMessages[count]) -// } -// count++ -// } -// require.Equal(t, len(expectedMessages), count) -// */ -// }) -// -// t.Run("get receipts", func(t *testing.T) { -// txH := randFelt(t) -// // There are common receipt fields shared by all of different transactions. -// commonReceipt := &core.TransactionReceipt{ -// TransactionHash: txH, -// Fee: randFelt(t), -// L2ToL1Message: []*core.L2ToL1Message{fillFelts(t, &core.L2ToL1Message{}), fillFelts(t, &core.L2ToL1Message{})}, -// ExecutionResources: &core.ExecutionResources{ -// BuiltinInstanceCounter: core.BuiltinInstanceCounter{ -// Pedersen: 1, -// RangeCheck: 2, -// Bitwise: 3, -// Output: 4, -// Ecsda: 5, -// EcOp: 6, -// Keccak: 7, -// Poseidon: 8, -// }, -// MemoryHoles: 9, -// Steps: 10, -// }, -// RevertReason: "some revert reason", -// Events: []*core.Event{fillFelts(t, &core.Event{}), fillFelts(t, &core.Event{})}, -// L1ToL2Message: fillFelts(t, &core.L1ToL2Message{}), -// } -// -// specReceiptCommon := &spec.Receipt_Common{ -// TransactionHash: core2p2p.AdaptHash(commonReceipt.TransactionHash), -// ActualFee: core2p2p.AdaptFelt(commonReceipt.Fee), -// MessagesSent: utils.Map(commonReceipt.L2ToL1Message, core2p2p.AdaptMessageToL1), -// ExecutionResources: core2p2p.AdaptExecutionResources(commonReceipt.ExecutionResources), -// RevertReason: commonReceipt.RevertReason, -// } -// -// invokeTx := &core.InvokeTransaction{TransactionHash: txH} -// expectedInvoke := &spec.Receipt{ -// Type: &spec.Receipt_Invoke_{ -// Invoke: &spec.Receipt_Invoke{ -// Common: specReceiptCommon, -// }, -// }, -// } -// -// declareTx := &core.DeclareTransaction{TransactionHash: txH} -// expectedDeclare := &spec.Receipt{ -// Type: &spec.Receipt_Declare_{ -// Declare: &spec.Receipt_Declare{ -// Common: specReceiptCommon, -// }, -// }, -// } -// -// l1Txn := &core.L1HandlerTransaction{ -// TransactionHash: txH, -// CallData: []*felt.Felt{new(felt.Felt).SetBytes([]byte("calldata 1")), new(felt.Felt).SetBytes([]byte("calldata 2"))}, -// ContractAddress: new(felt.Felt).SetBytes([]byte("contract address")), -// EntryPointSelector: new(felt.Felt).SetBytes([]byte("entry point selector")), -// Nonce: new(felt.Felt).SetBytes([]byte("nonce")), -// } -// expectedL1Handler := &spec.Receipt{ -// Type: &spec.Receipt_L1Handler_{ -// L1Handler: &spec.Receipt_L1Handler{ -// Common: specReceiptCommon, -// MsgHash: &spec.Hash{Elements: l1Txn.MessageHash()}, -// }, -// }, -// } -// -// deployAccTxn := &core.DeployAccountTransaction{ -// DeployTransaction: core.DeployTransaction{ -// TransactionHash: txH, -// ContractAddress: new(felt.Felt).SetBytes([]byte("contract address")), -// }, -// } -// expectedDeployAccount 
:= &spec.Receipt{ -// Type: &spec.Receipt_DeployAccount_{ -// DeployAccount: &spec.Receipt_DeployAccount{ -// Common: specReceiptCommon, -// ContractAddress: core2p2p.AdaptFelt(deployAccTxn.ContractAddress), -// }, -// }, -// } -// -// deployTxn := &core.DeployTransaction{ -// TransactionHash: txH, -// ContractAddress: new(felt.Felt).SetBytes([]byte("contract address")), -// } -// expectedDeploy := &spec.Receipt{ -// Type: &spec.Receipt_DeprecatedDeploy{ -// DeprecatedDeploy: &spec.Receipt_Deploy{ -// Common: specReceiptCommon, -// ContractAddress: core2p2p.AdaptFelt(deployTxn.ContractAddress), -// }, -// }, -// } -// -// tests := []struct { -// b *core.Block -// expectedRs *spec.Receipts -// }{ -// { -// b: &core.Block{ -// Header: &core.Header{Number: 0, Hash: randFelt(t)}, -// Transactions: []core.Transaction{invokeTx}, -// Receipts: []*core.TransactionReceipt{commonReceipt}, -// }, -// expectedRs: &spec.Receipts{Items: []*spec.Receipt{expectedInvoke}}, -// }, -// { -// b: &core.Block{ -// Header: &core.Header{Number: 1, Hash: randFelt(t)}, -// Transactions: []core.Transaction{declareTx}, -// Receipts: []*core.TransactionReceipt{commonReceipt}, -// }, -// expectedRs: &spec.Receipts{Items: []*spec.Receipt{expectedDeclare}}, -// }, -// { -// b: &core.Block{ -// Header: &core.Header{Number: 2, Hash: randFelt(t)}, -// Transactions: []core.Transaction{l1Txn}, -// Receipts: []*core.TransactionReceipt{commonReceipt}, -// }, -// expectedRs: &spec.Receipts{Items: []*spec.Receipt{expectedL1Handler}}, -// }, -// { -// b: &core.Block{ -// Header: &core.Header{Number: 3, Hash: randFelt(t)}, -// Transactions: []core.Transaction{deployAccTxn}, -// Receipts: []*core.TransactionReceipt{commonReceipt}, -// }, -// expectedRs: &spec.Receipts{Items: []*spec.Receipt{expectedDeployAccount}}, -// }, -// { -// b: &core.Block{ -// Header: &core.Header{Number: 4, Hash: randFelt(t)}, -// Transactions: []core.Transaction{deployTxn}, -// Receipts: []*core.TransactionReceipt{commonReceipt}, -// }, -// expectedRs: &spec.Receipts{Items: []*spec.Receipt{expectedDeploy}}, -// }, -// { -// // block with multiple txs receipts -// b: &core.Block{ -// Header: &core.Header{Number: 5, Hash: randFelt(t)}, -// Transactions: []core.Transaction{invokeTx, declareTx}, -// Receipts: []*core.TransactionReceipt{commonReceipt, commonReceipt}, -// }, -// expectedRs: &spec.Receipts{Items: []*spec.Receipt{expectedInvoke, expectedDeclare}}, -// }, -// } -// -// numOfBs := uint64(len(tests)) -// for _, test := range tests { -// mockReader.EXPECT().BlockByNumber(test.b.Number).Return(test.b, nil) -// } -// -// res, cErr := client.RequestReceipts(testCtx, &spec.ReceiptsRequest{Iteration: &spec.Iteration{ -// Start: &spec.Iteration_BlockNumber{BlockNumber: tests[0].b.Number}, -// Direction: spec.Iteration_Forward, -// Limit: numOfBs, -// Step: 1, -// }}) -// require.NoError(t, cErr) -// -// var count uint64 -// for receipts, valid := res(); valid; receipts, valid = res() { -// if count == numOfBs { -// assert.NotNil(t, receipts.GetFin()) -// continue -// } -// -// assert.Equal(t, count, receipts.Id.Number) -// -// expectedRs := tests[count].expectedRs -// assert.True(t, proto.Equal(expectedRs, receipts.GetReceipts())) -// count++ -// } -// require.Equal(t, numOfBs, count) -// }) -// -// t.Run("get txns", func(t *testing.T) { -// blocks := []*core.Block{ -// { -// Header: &core.Header{ -// Number: 0, -// }, -// Transactions: []core.Transaction{ -// fillFelts(t, &core.DeployTransaction{ -// ConstructorCallData: feltSlice(3), -// }), -// 
fillFelts(t, &core.L1HandlerTransaction{ -// CallData: feltSlice(2), -// Version: txVersion(1), -// }), -// }, -// }, -// { -// Header: &core.Header{ -// Number: 1, -// }, -// Transactions: []core.Transaction{ -// fillFelts(t, &core.DeployAccountTransaction{ -// DeployTransaction: core.DeployTransaction{ -// ConstructorCallData: feltSlice(3), -// Version: txVersion(1), -// }, -// TransactionSignature: feltSlice(2), -// }), -// }, -// }, -// { -// Header: &core.Header{ -// Number: 2, -// }, -// Transactions: []core.Transaction{ -// fillFelts(t, &core.DeclareTransaction{ -// TransactionSignature: feltSlice(2), -// Version: txVersion(0), -// }), -// fillFelts(t, &core.DeclareTransaction{ -// TransactionSignature: feltSlice(2), -// Version: txVersion(1), -// }), -// }, -// }, -// { -// Header: &core.Header{ -// Number: 3, -// }, -// Transactions: []core.Transaction{ -// fillFelts(t, &core.InvokeTransaction{ -// CallData: feltSlice(3), -// TransactionSignature: feltSlice(2), -// Version: txVersion(0), -// }), -// fillFelts(t, &core.InvokeTransaction{ -// CallData: feltSlice(4), -// TransactionSignature: feltSlice(2), -// Version: txVersion(1), -// }), -// }, -// }, -// } -// numOfBlocks := uint64(len(blocks)) -// -// for _, block := range blocks { -// mockReader.EXPECT().BlockByNumber(block.Number).Return(block, nil) -// } -// -// res, cErr := client.RequestTransactions(testCtx, &spec.TransactionsRequest{ -// Iteration: &spec.Iteration{ -// Start: &spec.Iteration_BlockNumber{ -// BlockNumber: blocks[0].Number, -// }, -// Direction: spec.Iteration_Forward, -// Limit: numOfBlocks, -// Step: 1, -// }, -// }) -// require.NoError(t, cErr) -// -// var count uint64 -// for txn, valid := res(); valid; txn, valid = res() { -// if count == numOfBlocks { -// assert.NotNil(t, txn.GetFin()) -// break -// } -// -// assert.Equal(t, count, txn.Id.Number) -// -// expectedTx := mapToExpectedTransactions(blocks[count]) -// assert.True(t, proto.Equal(expectedTx, txn.GetTransactions())) -// count++ -// } -// require.Equal(t, numOfBlocks, count) -// }) -// -// t.Run("get events", func(t *testing.T) { -// eventsPerBlock := [][]*core.Event{ -// {}, // block with no events -// { -// { -// From: randFelt(t), -// Data: feltSlice(1), -// Keys: feltSlice(1), -// }, -// }, -// { -// { -// From: randFelt(t), -// Data: feltSlice(2), -// Keys: feltSlice(2), -// }, -// { -// From: randFelt(t), -// Data: feltSlice(3), -// Keys: feltSlice(3), -// }, -// }, -// } -// for blockNumber, events := range eventsPerBlock { -// blockNumber := uint64(blockNumber) -// mockReader.EXPECT().BlockByNumber(blockNumber).Return(&core.Block{ -// Header: &core.Header{ -// Number: blockNumber, -// }, -// Receipts: []*core.TransactionReceipt{ -// { -// TransactionHash: new(felt.Felt).SetUint64(blockNumber), -// Events: events, -// }, -// }, -// }, nil) -// } -// -// numOfBlocks := uint64(len(eventsPerBlock)) -// res, cErr := client.RequestEvents(testCtx, &spec.EventsRequest{ -// Iteration: &spec.Iteration{ -// Start: &spec.Iteration_BlockNumber{ -// BlockNumber: 0, -// }, -// Direction: spec.Iteration_Forward, -// Limit: numOfBlocks, -// Step: 1, -// }, -// }) -// require.NoError(t, cErr) -// -// var count uint64 -// for evnt, valid := res(); valid; evnt, valid = res() { -// if count == numOfBlocks { -// assert.True(t, proto.Equal(&spec.Fin{}, evnt.GetFin())) -// count++ -// break -// } -// -// assert.Equal(t, count, evnt.Id.Number) -// -// passedEvents := eventsPerBlock[int(count)] -// expectedEventsResponse := &spec.EventsResponse_Events{ -// 
Events: &spec.Events{ -// Items: utils.Map(passedEvents, func(e *core.Event) *spec.Event { -// return core2p2p.AdaptEvent(e, new(felt.Felt).SetUint64(count)) -// }), -// }, -// } -// -// assert.True(t, proto.Equal(expectedEventsResponse.Events, evnt.GetEvents())) -// count++ -// } -// expectedCount := numOfBlocks + 1 // numOfBlocks messages with blocks + 1 fin message -// require.Equal(t, expectedCount, count) -// -// t.Run("block with multiple tx", func(t *testing.T) { -// blockNumber := uint64(0) -// mockReader.EXPECT().BlockByNumber(blockNumber).Return(&core.Block{ -// Header: &core.Header{ -// Number: blockNumber, -// }, -// Receipts: []*core.TransactionReceipt{ -// { -// TransactionHash: new(felt.Felt).SetUint64(0), -// Events: eventsPerBlock[0], -// }, -// { -// TransactionHash: new(felt.Felt).SetUint64(1), -// Events: eventsPerBlock[1], -// }, -// { -// TransactionHash: new(felt.Felt).SetUint64(2), -// Events: eventsPerBlock[2], -// }, -// }, -// }, nil) -// -// res, cErr = client.RequestEvents(testCtx, &spec.EventsRequest{ -// Iteration: &spec.Iteration{ -// Start: &spec.Iteration_BlockNumber{ -// BlockNumber: blockNumber, -// }, -// Direction: spec.Iteration_Forward, -// Limit: 1, -// Step: 1, -// }, -// }) -// -// expectedEventsResponse := &spec.EventsResponse_Events{ -// Events: &spec.Events{ -// Items: []*spec.Event{ -// core2p2p.AdaptEvent(eventsPerBlock[1][0], new(felt.Felt).SetUint64(1)), -// core2p2p.AdaptEvent(eventsPerBlock[2][0], new(felt.Felt).SetUint64(2)), -// core2p2p.AdaptEvent(eventsPerBlock[2][1], new(felt.Felt).SetUint64(2)), -// }, -// }, -// } -// count = 0 -// for evnt, valid := res(); valid; evnt, valid = res() { -// if count == 1 { -// assert.True(t, proto.Equal(&spec.Fin{}, evnt.GetFin())) -// break -// } -// -// assert.Equal(t, count, evnt.Id.Number) -// -// assert.True(t, proto.Equal(expectedEventsResponse.Events, evnt.GetEvents())) -// count++ -// } -// require.NoError(t, cErr) -// }) -// }) -//} -// -// func expectedHeaderResponse(h *core.Header, c *core.BlockCommitments) *spec.BlockHeadersResponse { -// adaptHash := core2p2p.AdaptHash -// return &spec.BlockHeadersResponse{ -// Part: []*spec.BlockHeadersResponsePart{ -// { -// HeaderMessage: &spec.BlockHeadersResponsePart_Header{ -// Header: &spec.BlockHeader{ -// ParentHash: adaptHash(h.ParentHash), -// Number: h.Number, -// Time: timestamppb.New(time.Unix(int64(h.Timestamp), 0)), -// SequencerAddress: core2p2p.AdaptAddress(h.SequencerAddress), -// State: &spec.Patricia{ -// Height: 251, -// Root: adaptHash(h.GlobalStateRoot), -// }, -// Transactions: &spec.Merkle{ -// NLeaves: uint32(h.TransactionCount), -// Root: adaptHash(c.TransactionCommitment), -// }, -// Events: &spec.Merkle{ -// NLeaves: uint32(h.EventCount), -// Root: adaptHash(c.EventCommitment), -// }, -// }, -// }, -// }, -// { -// HeaderMessage: &spec.BlockHeadersResponsePart_Signatures{ -// Signatures: &spec.Signatures{ -// Block: core2p2p.AdaptBlockID(h), -// Signatures: utils.Map(h.Signatures, core2p2p.AdaptSignature), -// }, -// }, -// }, -// }, -// } -//} -// -// func mapToExpectedTransactions(block *core.Block) *spec.Transactions { -// return &spec.Transactions{ -// Items: utils.Map(block.Transactions, core2p2p.AdaptTransaction), -// } -//} -// -// func txVersion(v uint64) *core.TransactionVersion { -// var f felt.Felt -// f.SetUint64(v) -// -// txV := core.TransactionVersion(f) -// return &txV -//} -// -// func feltSlice(n int) []*felt.Felt { -// return make([]*felt.Felt, n) -//} -// -// func randFelt(t *testing.T) *felt.Felt { -// 
t.Helper() -// -// f, err := new(felt.Felt).SetRandom() -// require.NoError(t, err) -// -// return f -//} -// -// func fillFelts[T any](t *testing.T, i T) T { -// v := reflect.ValueOf(i) -// if v.Kind() == reflect.Ptr && !v.IsNil() { -// v = v.Elem() -// } -// typ := v.Type() -// -// const feltTypeStr = "*felt.Felt" -// -// for i := 0; i < v.NumField(); i++ { -// f := v.Field(i) -// ftyp := typ.Field(i).Type // Get the type of the current field -// -// // Skip unexported fields -// if !f.CanSet() { -// continue -// } -// -// switch f.Kind() { -// case reflect.Ptr: -// // Check if the type is Felt -// if ftyp.String() == feltTypeStr { -// f.Set(reflect.ValueOf(randFelt(t))) -// } else if f.IsNil() { -// // Initialise the pointer if it's nil -// f.Set(reflect.New(ftyp.Elem())) -// } -// -// if f.Elem().Kind() == reflect.Struct { -// // Recursive call for nested structs -// fillFelts(t, f.Interface()) -// } -// case reflect.Slice: -// // For slices, loop and populate -// for j := 0; j < f.Len(); j++ { -// elem := f.Index(j) -// if elem.Type().String() == feltTypeStr { -// elem.Set(reflect.ValueOf(randFelt(t))) -// } -// } -// case reflect.Struct: -// // Recursive call for nested structs -// fillFelts(t, f.Addr().Interface()) -// } -// } -// -// return i -//} -// -// func sortContractDiff(diff []*spec.StateDiff_ContractDiff) { -// sort.Slice(diff, func(i, j int) bool { -// iAddress := diff[i].Address -// jAddress := diff[j].Address -// return bytes.Compare(iAddress.Elements, jAddress.Elements) < 0 -// }) -//} -// -// func noError[T any](t *testing.T, f func() (T, error)) T { -// t.Helper() -// -// v, err := f() -// require.NoError(t, err) -// -// return v -//} diff --git a/p2p/sync.go b/p2p/sync.go index b49af2dc69..47f58936bf 100644 --- a/p2p/sync.go +++ b/p2p/sync.go @@ -670,9 +670,3 @@ func (s *syncService) createIteratorForBlock(blockNumber uint64) *spec.Iteration func (s *syncService) WithListener(l junoSync.EventListener) { s.listener = l } - -//nolint:unused -func (s *syncService) sleep(d time.Duration) { - s.log.Debugw("Sleeping...", "for", d) - time.Sleep(d) -} From 2b1b21977a7df072bebfc5cf22886b871e5cc262 Mon Sep 17 00:00:00 2001 From: aleven1999 Date: Thu, 12 Dec 2024 12:11:28 +0400 Subject: [PATCH 08/10] Remove unused code (#2318) --- adapters/vm2core/vm2core.go | 34 -------------------------------- adapters/vm2core/vm2core_test.go | 32 ------------------------------ 2 files changed, 66 deletions(-) diff --git a/adapters/vm2core/vm2core.go b/adapters/vm2core/vm2core.go index 3646bd98b1..dc505e86d1 100644 --- a/adapters/vm2core/vm2core.go +++ b/adapters/vm2core/vm2core.go @@ -10,29 +10,6 @@ import ( "github.com/ethereum/go-ethereum/common" ) -func AdaptExecutionResources(resources *vm.ExecutionResources) *core.ExecutionResources { - return &core.ExecutionResources{ - BuiltinInstanceCounter: core.BuiltinInstanceCounter{ - Pedersen: resources.Pedersen, - RangeCheck: resources.RangeCheck, - Bitwise: resources.Bitwise, - Ecsda: resources.Ecdsa, - EcOp: resources.EcOp, - Keccak: resources.Keccak, - Poseidon: resources.Poseidon, - SegmentArena: resources.SegmentArena, - Output: resources.Output, - AddMod: resources.AddMod, - MulMod: resources.MulMod, - RangeCheck96: resources.RangeCheck96, - }, - MemoryHoles: resources.MemoryHoles, - Steps: resources.Steps, - DataAvailability: adaptDA(resources.DataAvailability), - TotalGasConsumed: nil, // todo: fill after 0.13.2 - } -} - func AdaptOrderedEvent(event vm.OrderedEvent) *core.Event { return &core.Event{ From: event.From, @@ -62,14 +39,3 @@ 
func AdaptOrderedEvents(events []vm.OrderedEvent) []*core.Event { }) return utils.Map(events, AdaptOrderedEvent) } - -func adaptDA(da *vm.DataAvailability) *core.DataAvailability { - if da == nil { - return nil - } - - return &core.DataAvailability{ - L1Gas: da.L1Gas, - L1DataGas: da.L1DataGas, - } -} diff --git a/adapters/vm2core/vm2core_test.go b/adapters/vm2core/vm2core_test.go index 168cfa07e7..2a32af8641 100644 --- a/adapters/vm2core/vm2core_test.go +++ b/adapters/vm2core/vm2core_test.go @@ -68,35 +68,3 @@ func TestAdaptOrderedMessagesToL1(t *testing.T) { vm2core.AdaptOrderedMessageToL1(messages[0]), }, vm2core.AdaptOrderedMessagesToL1(messages)) } - -func TestAdaptExecutionResources(t *testing.T) { - require.Equal(t, &core.ExecutionResources{ - BuiltinInstanceCounter: core.BuiltinInstanceCounter{ - Pedersen: 1, - RangeCheck: 2, - Bitwise: 3, - Ecsda: 4, - EcOp: 5, - Keccak: 6, - Poseidon: 7, - SegmentArena: 8, - Output: 11, - }, - MemoryHoles: 9, - Steps: 10, - }, vm2core.AdaptExecutionResources(&vm.ExecutionResources{ - ComputationResources: vm.ComputationResources{ - Pedersen: 1, - RangeCheck: 2, - Bitwise: 3, - Ecdsa: 4, - EcOp: 5, - Keccak: 6, - Poseidon: 7, - SegmentArena: 8, - MemoryHoles: 9, - Steps: 10, - Output: 11, - }, - })) -} From 65b7507fda8a8e0ee1442c9eb044ccb646979804 Mon Sep 17 00:00:00 2001 From: Ng Wei Han <47109095+weiihann@users.noreply.github.com> Date: Thu, 12 Dec 2024 18:20:55 +0800 Subject: [PATCH 09/10] Fix and refactor trie proof logics (#2252) --- core/trie/key.go | 115 ++- core/trie/key_test.go | 156 +++- core/trie/node.go | 53 ++ core/trie/proof.go | 925 ++++++++++----------- core/trie/proof_test.go | 1580 +++++++++++++++--------------------- core/trie/trie.go | 197 +++-- core/trie/trie_pkg_test.go | 32 +- db/pebble/db.go | 2 +- utils/orderedset.go | 67 ++ 9 files changed, 1536 insertions(+), 1591 deletions(-) create mode 100644 utils/orderedset.go diff --git a/core/trie/key.go b/core/trie/key.go index 7f0e6af609..2d94c4ad73 100644 --- a/core/trie/key.go +++ b/core/trie/key.go @@ -3,13 +3,14 @@ package trie import ( "bytes" "encoding/hex" - "errors" "fmt" "math/big" "github.com/NethermindEth/juno/core/felt" ) +var NilKey = &Key{len: 0, bitset: [32]byte{}} + type Key struct { len uint8 bitset [32]byte @@ -24,26 +25,6 @@ func NewKey(length uint8, keyBytes []byte) Key { return k } -func (k *Key) SubKey(n uint8) (*Key, error) { - if n > k.len { - return nil, errors.New(fmt.Sprint("cannot subtract key of length %i from key of length %i", n, k.len)) - } - - newKey := &Key{len: n} - copy(newKey.bitset[:], k.bitset[len(k.bitset)-int((k.len+7)/8):]) //nolint:mnd - - // Shift right by the number of bits that are not needed - shift := k.len - n - for i := len(newKey.bitset) - 1; i >= 0; i-- { - newKey.bitset[i] >>= shift - if i > 0 { - newKey.bitset[i] |= newKey.bitset[i-1] << (8 - shift) - } - } - - return newKey, nil -} - func (k *Key) bytesNeeded() uint { const byteBits = 8 return (uint(k.len) + (byteBits - 1)) / byteBits @@ -96,24 +77,30 @@ func (k *Key) Equal(other *Key) bool { return k.len == other.len && k.bitset == other.bitset } -func (k *Key) Test(bit uint8) bool { +// IsBitSet returns whether the bit at the given position is 1. +// Position 0 represents the least significant (rightmost) bit. 
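+// For example, given the 8-bit key 0b1011_0000 (values chosen for illustration):
+//
+//	k := NewKey(8, []byte{0b1011_0000})
+//	k.IsBitSet(7) // true:  the most significant bit is 1
+//	k.IsBitSet(4) // true
+//	k.IsBitSet(0) // false: the least significant bit is 0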
+func (k *Key) IsBitSet(position uint8) bool { const LSB = uint8(0x1) - byteIdx := bit / 8 + byteIdx := position / 8 byteAtIdx := k.bitset[len(k.bitset)-int(byteIdx)-1] - bitIdx := bit % 8 + bitIdx := position % 8 return ((byteAtIdx >> bitIdx) & LSB) != 0 } -func (k *Key) String() string { - return fmt.Sprintf("(%d) %s", k.len, hex.EncodeToString(k.bitset[:])) -} - -// DeleteLSB right shifts and shortens the key -func (k *Key) DeleteLSB(n uint8) { +// shiftRight removes n least significant bits from the key by performing a right shift +// operation and reducing the key length. For example, if the key contains bits +// "1111 0000" (length=8) and n=4, the result will be "1111" (length=4). +// +// The operation is destructive - it modifies the key in place. +func (k *Key) shiftRight(n uint8) { if k.len < n { panic("deleting more bits than there are") } + if n == 0 { + return + } + var bigInt big.Int bigInt.SetBytes(k.bitset[:]) bigInt.Rsh(&bigInt, uint(n)) @@ -121,6 +108,17 @@ func (k *Key) DeleteLSB(n uint8) { k.len -= n } +// MostSignificantBits returns a new key with the most significant n bits of the current key. +func (k *Key) MostSignificantBits(n uint8) (*Key, error) { + if n > k.len { + return nil, fmt.Errorf("cannot get more bits than the key length") + } + + keyCopy := k.Copy() + keyCopy.shiftRight(k.len - n) + return &keyCopy, nil +} + // Truncate truncates key to `length` bits by clearing the remaining upper bits func (k *Key) Truncate(length uint8) { k.len = length @@ -136,20 +134,53 @@ func (k *Key) Truncate(length uint8) { } } -func (k *Key) RemoveLastBit() { - if k.len == 0 { - return - } +func (k *Key) String() string { + return fmt.Sprintf("(%d) %s", k.len, hex.EncodeToString(k.bitset[:])) +} - k.len-- +// Copy returns a deep copy of the key +func (k *Key) Copy() Key { + newKey := Key{len: k.len} + copy(newKey.bitset[:], k.bitset[:]) + return newKey +} - unusedBytes := k.unusedBytes() - clear(unusedBytes) +func (k *Key) Bytes() [32]byte { + var result [32]byte + copy(result[:], k.bitset[:]) + return result +} - // clear upper bits on the last used byte - inUseBytes := k.inUseBytes() - unusedBitsCount := 8 - (k.len % 8) - if unusedBitsCount != 8 && len(inUseBytes) > 0 { - inUseBytes[0] = (inUseBytes[0] << unusedBitsCount) >> unusedBitsCount +// findCommonKey finds the set of common MSB bits in two key bitsets. +func findCommonKey(longerKey, shorterKey *Key) (Key, bool) { + divergentBit := findDivergentBit(longerKey, shorterKey) + + if divergentBit == 0 { + return *NilKey, false } + + commonKey := *shorterKey + commonKey.shiftRight(shorterKey.Len() - divergentBit + 1) + return commonKey, divergentBit == shorterKey.Len()+1 +} + +// findDivergentBit finds the first bit that is different between two keys, +// starting from the most significant bit of both keys. 
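+// For example, with longerKey = 0b1011 (len 4) and shorterKey = 0b10 (len 2), the two
+// leading bits of both keys match, so the loop runs one step past shorterKey.Len() and
+// returns 3 (shorterKey.Len()+1), meaning shorterKey is a prefix of longerKey. With
+// longerKey = 0b1111 instead, the second most significant bits differ (1 vs 0) and 2 is
+// returned.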
+func findDivergentBit(longerKey, shorterKey *Key) uint8 { + divergentBit := uint8(0) + for divergentBit <= shorterKey.Len() && + longerKey.IsBitSet(longerKey.Len()-divergentBit) == shorterKey.IsBitSet(shorterKey.Len()-divergentBit) { + divergentBit++ + } + return divergentBit +} + +func isSubset(longerKey, shorterKey *Key) bool { + divergentBit := findDivergentBit(longerKey, shorterKey) + return divergentBit == shorterKey.Len()+1 +} + +func FeltToKey(length uint8, key *felt.Felt) Key { + keyBytes := key.Bytes() + return NewKey(length, keyBytes[:]) } diff --git a/core/trie/key_test.go b/core/trie/key_test.go index 8d56a31e0c..3867678e6e 100644 --- a/core/trie/key_test.go +++ b/core/trie/key_test.go @@ -68,47 +68,6 @@ func BenchmarkKeyEncoding(b *testing.B) { } } -func TestKeyTest(t *testing.T) { - key := trie.NewKey(44, []byte{0x10, 0x02}) - for i := 0; i < int(key.Len()); i++ { - assert.Equal(t, i == 1 || i == 12, key.Test(uint8(i)), i) - } -} - -func TestDeleteLSB(t *testing.T) { - key := trie.NewKey(16, []byte{0xF3, 0x04}) - - tests := map[string]struct { - shiftAmount uint8 - expectedKey trie.Key - }{ - "delete 0 bits": { - shiftAmount: 0, - expectedKey: key, - }, - "delete 4 bits": { - shiftAmount: 4, - expectedKey: trie.NewKey(12, []byte{0x0F, 0x30}), - }, - "delete 8 bits": { - shiftAmount: 8, - expectedKey: trie.NewKey(8, []byte{0xF3}), - }, - "delete 9 bits": { - shiftAmount: 9, - expectedKey: trie.NewKey(7, []byte{0x79}), - }, - } - - for desc, test := range tests { - t.Run(desc, func(t *testing.T) { - copyKey := key - copyKey.DeleteLSB(test.shiftAmount) - assert.Equal(t, test.expectedKey, copyKey) - }) - } -} - func TestTruncate(t *testing.T) { tests := map[string]struct { key trie.Key @@ -153,3 +112,118 @@ func TestTruncate(t *testing.T) { }) } } + +func TestKeyTest(t *testing.T) { + key := trie.NewKey(44, []byte{0x10, 0x02}) + for i := 0; i < int(key.Len()); i++ { + assert.Equal(t, i == 1 || i == 12, key.IsBitSet(uint8(i)), i) + } +} + +func TestIsBitSet(t *testing.T) { + tests := map[string]struct { + key trie.Key + position uint8 + expected bool + }{ + "single byte, LSB set": { + key: trie.NewKey(8, []byte{0x01}), + position: 0, + expected: true, + }, + "single byte, MSB set": { + key: trie.NewKey(8, []byte{0x80}), + position: 7, + expected: true, + }, + "single byte, middle bit set": { + key: trie.NewKey(8, []byte{0x10}), + position: 4, + expected: true, + }, + "single byte, bit not set": { + key: trie.NewKey(8, []byte{0xFE}), + position: 0, + expected: false, + }, + "multiple bytes, LSB set": { + key: trie.NewKey(16, []byte{0x00, 0x02}), + position: 1, + expected: true, + }, + "multiple bytes, MSB set": { + key: trie.NewKey(16, []byte{0x01, 0x00}), + position: 8, + expected: true, + }, + "multiple bytes, no bits set": { + key: trie.NewKey(16, []byte{0x00, 0x00}), + position: 7, + expected: false, + }, + "check all bits in pattern": { + key: trie.NewKey(8, []byte{0xA5}), // 10100101 + position: 0, + expected: true, + }, + } + + // Additional test for 0xA5 pattern + key := trie.NewKey(8, []byte{0xA5}) // 10100101 + expectedBits := []bool{true, false, true, false, false, true, false, true} + for i, expected := range expectedBits { + assert.Equal(t, expected, key.IsBitSet(uint8(i)), "bit %d in 0xA5", i) + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + result := tc.key.IsBitSet(tc.position) + assert.Equal(t, tc.expected, result) + }) + } +} + +func TestMostSignificantBits(t *testing.T) { + tests := []struct { + name string + key trie.Key + n uint8 + 
want trie.Key + expectErr bool + }{ + { + name: "Valid case", + key: trie.NewKey(8, []byte{0b11110000}), + n: 4, + want: trie.NewKey(4, []byte{0b00001111}), + expectErr: false, + }, + { + name: "Request more bits than available", + key: trie.NewKey(8, []byte{0b11110000}), + n: 10, + want: trie.Key{}, + expectErr: true, + }, + { + name: "Zero bits requested", + key: trie.NewKey(8, []byte{0b11110000}), + n: 0, + want: trie.NewKey(0, []byte{}), + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.key.MostSignificantBits(tt.n) + if (err != nil) != tt.expectErr { + t.Errorf("MostSignificantBits() error = %v, expectErr %v", err, tt.expectErr) + return + } + if !tt.expectErr && !got.Equal(&tt.want) { + t.Errorf("MostSignificantBits() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/core/trie/node.go b/core/trie/node.go index db9cb85206..2ef176f92a 100644 --- a/core/trie/node.go +++ b/core/trie/node.go @@ -3,6 +3,7 @@ package trie import ( "bytes" "errors" + "fmt" "github.com/NethermindEth/juno/core/felt" ) @@ -138,3 +139,55 @@ func (n *Node) UnmarshalBinary(data []byte) error { n.RightHash.SetBytes(data[:felt.Bytes]) return nil } + +func (n *Node) String() string { + return fmt.Sprintf("Node{Value: %s, Left: %s, Right: %s, LeftHash: %s, RightHash: %s}", n.Value, n.Left, n.Right, n.LeftHash, n.RightHash) +} + +// Update the receiver with non-nil fields from the `other` Node. +// If a field is non-nil in both Nodes, they must be equal, or an error is returned. +// +// This method modifies the receiver in-place and returns an error if any field conflicts are detected. +// +//nolint:gocyclo +func (n *Node) Update(other *Node) error { + // First validate all fields for conflicts + if n.Value != nil && other.Value != nil && !n.Value.Equal(other.Value) { + return fmt.Errorf("conflicting Values: %v != %v", n.Value, other.Value) + } + + if n.Left != nil && other.Left != nil && !n.Left.Equal(NilKey) && !other.Left.Equal(NilKey) && !n.Left.Equal(other.Left) { + return fmt.Errorf("conflicting Left keys: %v != %v", n.Left, other.Left) + } + + if n.Right != nil && other.Right != nil && !n.Right.Equal(NilKey) && !other.Right.Equal(NilKey) && !n.Right.Equal(other.Right) { + return fmt.Errorf("conflicting Right keys: %v != %v", n.Right, other.Right) + } + + if n.LeftHash != nil && other.LeftHash != nil && !n.LeftHash.Equal(other.LeftHash) { + return fmt.Errorf("conflicting LeftHash: %v != %v", n.LeftHash, other.LeftHash) + } + + if n.RightHash != nil && other.RightHash != nil && !n.RightHash.Equal(other.RightHash) { + return fmt.Errorf("conflicting RightHash: %v != %v", n.RightHash, other.RightHash) + } + + // After validation, perform all updates + if other.Value != nil { + n.Value = other.Value + } + if other.Left != nil && !other.Left.Equal(NilKey) { + n.Left = other.Left + } + if other.Right != nil && !other.Right.Equal(NilKey) { + n.Right = other.Right + } + if other.LeftHash != nil { + n.LeftHash = other.LeftHash + } + if other.RightHash != nil { + n.RightHash = other.RightHash + } + + return nil +} diff --git a/core/trie/proof.go b/core/trie/proof.go index 6dcbe3960c..bc4b66d0d9 100644 --- a/core/trie/proof.go +++ b/core/trie/proof.go @@ -4,18 +4,21 @@ import ( "errors" "fmt" + "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/utils" ) -var ( - ErrUnknownProofNode = errors.New("unknown proof node") - ErrChildHashNotFound = errors.New("can't determine the child hash 
from the parent and child") -) +type ProofNodeSet = utils.OrderedSet[felt.Felt, ProofNode] + +func NewProofNodeSet() *ProofNodeSet { + return utils.NewOrderedSet[felt.Felt, ProofNode]() +} type ProofNode interface { Hash(hash hashFunc) *felt.Felt Len() uint8 - PrettyPrint() + String() string } type Binary struct { @@ -31,10 +34,8 @@ func (b *Binary) Len() uint8 { return 1 } -func (b *Binary) PrettyPrint() { - fmt.Printf(" Binary:\n") - fmt.Printf(" LeftHash: %v\n", b.LeftHash) - fmt.Printf(" RightHash: %v\n", b.RightHash) +func (b *Binary) String() string { + return fmt.Sprintf("Binary: %v:\n\tLeftHash: %v\n\tRightHash: %v\n", b.Hash(crypto.Pedersen), b.LeftHash, b.RightHash) } type Edge struct { @@ -54,623 +55,585 @@ func (e *Edge) Len() uint8 { return e.Path.Len() } -func (e *Edge) PrettyPrint() { - fmt.Printf(" Edge:\n") - fmt.Printf(" Child: %v\n", e.Child) - fmt.Printf(" Path: %v\n", e.Path) +func (e *Edge) String() string { + return fmt.Sprintf("Edge: %v:\n\tChild: %v\n\tPath: %v\n", e.Hash(crypto.Pedersen), e.Child, e.Path) } -func GetBoundaryProofs(leftBoundary, rightBoundary *Key, tri *Trie) ([2][]ProofNode, error) { - proofs := [2][]ProofNode{} - leftProof, err := GetProof(leftBoundary, tri) - if err != nil { - return proofs, err - } - rightProof, err := GetProof(rightBoundary, tri) +// Prove generates a Merkle proof for a given key in the trie. +// The result contains the proof nodes on the path from the root to the leaf. +// The value is included in the proof if the key is present in the trie. +// If the key is not present, the proof will contain the nodes on the path to the closest ancestor. +func (t *Trie) Prove(key *felt.Felt, proof *ProofNodeSet) error { + k := t.FeltToKey(key) + + nodesFromRoot, err := t.nodesFromRoot(&k) if err != nil { - return proofs, err + return err } - proofs[0] = leftProof - proofs[1] = rightProof - return proofs, nil -} -func isEdge(parentKey *Key, sNode StorageNode) bool { - sNodeLen := sNode.key.len - if parentKey == nil { // Root - return sNodeLen != 0 - } - return sNodeLen-parentKey.len > 1 -} + var parentKey *Key -// Note: we need to account for the fact that Junos Trie has nodes that are Binary AND Edge, -// whereas the protocol requires nodes that are Binary XOR Edge -func transformNode(tri *Trie, parentKey *Key, sNode StorageNode) (*Edge, *Binary, error) { - isEdgeBool := isEdge(parentKey, sNode) + for i, sNode := range nodesFromRoot { + sNodeEdge, sNodeBinary, err := storageNodeToProofNode(t, parentKey, sNode) + if err != nil { + return err + } + isLeaf := sNode.key.len == t.height - var edge *Edge - if isEdgeBool { - edgePath := path(sNode.key, parentKey) - edge = &Edge{ - Path: &edgePath, - Child: sNode.node.Value, + if sNodeEdge != nil && !isLeaf { // Internal Edge + proof.Put(*sNodeEdge.Hash(t.hash), sNodeEdge) + proof.Put(*sNodeBinary.Hash(t.hash), sNodeBinary) + } else if sNodeEdge == nil && !isLeaf { // Internal Binary + proof.Put(*sNodeBinary.Hash(t.hash), sNodeBinary) + } else if sNodeEdge != nil && isLeaf { // Leaf Edge + proof.Put(*sNodeEdge.Hash(t.hash), sNodeEdge) + } else if sNodeEdge == nil && sNodeBinary == nil { // sNode is a binary leaf + break } + parentKey = nodesFromRoot[i].key } - if sNode.key.len == tri.height { // Leaf - return edge, nil, nil - } - lNode, err := tri.GetNodeFromKey(sNode.node.Left) - if err != nil { - return nil, nil, err - } - rNode, err := tri.GetNodeFromKey(sNode.node.Right) + return nil +} + +// GetRangeProof generates a range proof for the given range of keys. 
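+// It is the union of two single-key proofs, one for each end of the range (see Prove above).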
+// The proof contains the proof nodes on the path from the root to the closest ancestor of the left and right keys. +func (t *Trie) GetRangeProof(leftKey, rightKey *felt.Felt, proofSet *ProofNodeSet) error { + err := t.Prove(leftKey, proofSet) if err != nil { - return nil, nil, err + return err } - rightHash := rNode.Value - if isEdge(sNode.key, StorageNode{node: rNode, key: sNode.node.Right}) { - edgePath := path(sNode.node.Right, sNode.key) - rEdge := &Edge{ - Path: &edgePath, - Child: rNode.Value, - } - rightHash = rEdge.Hash(tri.hash) - } - leftHash := lNode.Value - if isEdge(sNode.key, StorageNode{node: lNode, key: sNode.node.Left}) { - edgePath := path(sNode.node.Left, sNode.key) - lEdge := &Edge{ - Path: &edgePath, - Child: lNode.Value, - } - leftHash = lEdge.Hash(tri.hash) + // If they are the same key, don't need to generate the proof again + if leftKey.Equal(rightKey) { + return nil } - binary := &Binary{ - LeftHash: leftHash, - RightHash: rightHash, - } - - return edge, binary, nil -} -// pathSplitOccurredCheck checks if there happens at most one split in the merged path -// loops through the merged paths if left and right hashes of a node exist in the nodeHashes -// then a split happened in case of multiple splits it returns an error -func pathSplitOccurredCheck(mergedPath []ProofNode, nodeHashes map[felt.Felt]ProofNode) error { - splitHappened := false - for _, node := range mergedPath { - switch node := node.(type) { - case *Edge: - continue - case *Binary: - _, leftExists := nodeHashes[*node.LeftHash] - _, rightExists := nodeHashes[*node.RightHash] - if leftExists && rightExists { - if splitHappened { - return errors.New("split happened more than once") - } - splitHappened = true - } - default: - return fmt.Errorf("%w: %T", ErrUnknownProofNode, node) - } + err = t.Prove(rightKey, proofSet) + if err != nil { + return err } + return nil } -func rootNodeExistsCheck(rootHash *felt.Felt, nodeHashes map[felt.Felt]ProofNode) (ProofNode, error) { - currNode, rootExists := nodeHashes[*rootHash] - if !rootExists { - return currNode, errors.New("root hash not found in the merged path") - } +// VerifyProof verifies that a proof path is valid for a given key in a binary trie. +// It walks through the proof nodes, verifying each step matches the expected path to reach the key. +// +// The verification process: +// 1. Starts at the root hash and retrieves the corresponding proof node +// 2. For each proof node: +// - Verifies the node's computed hash matches the expected hash +// - For Binary nodes: +// -- Uses the next unprocessed bit in the key to choose left/right path +// -- If key bit is 0, takes left path; if 1, takes right path +// - For Edge nodes: +// -- Verifies the compressed path matches the corresponding bits in the key +// -- Moves to the child node if paths match +// +// 3. 
Continues until all bits in the key are processed +// +// The proof is considered invalid if: +// - Any proof node is missing from the OrderedSet +// - Any node's computed hash doesn't match its expected hash +// - The path bits don't match the key bits +// - The proof ends before processing all key bits +func VerifyProof(root, keyFelt *felt.Felt, proof *ProofNodeSet, hash hashFunc) (*felt.Felt, error) { + key := FeltToKey(globalTrieHeight, keyFelt) + expectedHash := root + keyLen := key.Len() - return currNode, nil -} + var curPos uint8 + for { + proofNode, ok := proof.Get(*expectedHash) + if !ok { + return nil, fmt.Errorf("proof node not found, expected hash: %s", expectedHash.String()) + } -// traverseNodes traverses the merged proof path starting at `currNode` -// and adds nodes to `path` slice. It stops when the split node is added -// or the path is exhausted, and `currNode` children are not included -// in the path (nodeHashes) -func traverseNodes(currNode ProofNode, path *[]ProofNode, nodeHashes map[felt.Felt]ProofNode) { - *path = append(*path, currNode) + // Verify the hash matches + if !proofNode.Hash(hash).Equal(expectedHash) { + return nil, fmt.Errorf("proof node hash mismatch, expected hash: %s, got hash: %s", expectedHash.String(), proofNode.Hash(hash).String()) + } - switch currNode := currNode.(type) { - case *Binary: - nodeLeft, leftExist := nodeHashes[*currNode.LeftHash] - nodeRight, rightExist := nodeHashes[*currNode.RightHash] + switch node := proofNode.(type) { + case *Binary: // Binary nodes represent left/right choices + if key.Len() <= curPos { + return nil, fmt.Errorf("key length less than current position, key length: %d, current position: %d", key.Len(), curPos) + } + // Determine the next node to traverse based on the next bit position + expectedHash = node.LeftHash + if key.IsBitSet(keyLen - curPos - 1) { + expectedHash = node.RightHash + } + curPos++ + case *Edge: // Edge nodes represent paths between binary nodes + if !verifyEdgePath(&key, node.Path, curPos) { + return &felt.Zero, nil + } - if leftExist && rightExist { - return - } else if leftExist { - traverseNodes(nodeLeft, path, nodeHashes) - } else if rightExist { - traverseNodes(nodeRight, path, nodeHashes) + // Move to the immediate child node + curPos += node.Path.Len() + expectedHash = node.Child } - case *Edge: - edgeNode, exist := nodeHashes[*currNode.Child] - if exist { - traverseNodes(edgeNode, path, nodeHashes) + + // We've consumed all bits in our path + if curPos >= keyLen { + return expectedHash, nil } } } -// MergeProofPaths removes duplicates and merges proof paths into a single path -// merges paths in the specified order [commonNodes..., leftNodes..., rightNodes...] -// ordering of the merged path is not important -// since SplitProofPath can discover the left and right paths using the merged path and the rootHash -func MergeProofPaths(leftPath, rightPath []ProofNode, hash hashFunc) ([]ProofNode, *felt.Felt, error) { - merged := []ProofNode{} - minLen := min(len(leftPath), len(rightPath)) - - if len(leftPath) == 0 || len(rightPath) == 0 { - return merged, nil, errors.New("empty proof paths") - } - - if !leftPath[0].Hash(hash).Equal(rightPath[0].Hash(hash)) { - return merged, nil, errors.New("roots of the proof paths are different") +// VerifyRangeProof checks the validity of given key-value pairs and range proof against a provided root hash. +// The key-value pairs should be consecutive (no gaps) and monotonically increasing. 
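+// Values must be non-zero, since a zero value encodes a deleted leaf and makes the proof invalid.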
+// The range proof contains two edge proofs: one for the first key and another for the last key. +// Both edge proofs can be for existent or non-existent keys. +// This function handles the following special cases: +// +// - All elements proof: The proof can be nil if the range includes all leaves in the trie. +// - Single element proof: Both left and right edge proofs are identical, and the range contains only one element. +// - Zero element proof: A single edge proof suffices for verification. The proof is invalid if there are additional elements. +// +// The function returns a boolean indicating if there are more elements and an error if the range proof is invalid. +// +// TODO(weiihann): Given a binary leaf and a left-sibling first key, if the right sibling is removed, the proof would still be valid. +// Conversely, given a binary leaf and a right-sibling last key, if the left sibling is removed, the proof would still be valid. +// Range proof should not be valid for both of these cases, but currently is, which is an attack vector. +// The problem probably lies in how we do root hash calculation. +func VerifyRangeProof(root, first *felt.Felt, keys, values []*felt.Felt, proof *ProofNodeSet) (bool, error) { //nolint:funlen,gocyclo + // Ensure the number of keys and values are the same + if len(keys) != len(values) { + return false, fmt.Errorf("inconsistent length of proof data, keys: %d, values: %d", len(keys), len(values)) } - rootHash := leftPath[0].Hash(hash) - - // Get duplicates and insert by one - i := 0 - for i = 0; i < minLen; i++ { - leftNode := leftPath[i] - rightNode := rightPath[i] + // Ensure all keys are monotonically increasing and values contain no deletions + for i := 0; i < len(keys); i++ { + if i < len(keys)-1 && keys[i].Cmp(keys[i+1]) > 0 { + return false, errors.New("keys are not monotonic increasing") + } - if leftNode.Hash(hash).Equal(rightNode.Hash(hash)) { - merged = append(merged, leftNode) - } else { - break + if values[i] == nil || values[i].Equal(&felt.Zero) { + return false, errors.New("range contains empty leaf") } } - // Add rest of the nodes - merged = append(merged, leftPath[i:]...) - merged = append(merged, rightPath[i:]...) 
- - return merged, rootHash, nil -} - -// SplitProofPath splits the merged proof path into two paths (left and right), which were merged before -// it first validates that the merged path is not circular, the split happens at most once and rootHash exists -// then calls traverseNodes to split the path to left and right paths -func SplitProofPath(mergedPath []ProofNode, rootHash *felt.Felt, hash hashFunc) ([]ProofNode, []ProofNode, error) { - commonPath := []ProofNode{} - leftPath := []ProofNode{} - rightPath := []ProofNode{} - nodeHashes := make(map[felt.Felt]ProofNode) - - for _, node := range mergedPath { - nodeHash := node.Hash(hash) - _, nodeExists := nodeHashes[*nodeHash] - - if nodeExists { - return leftPath, rightPath, errors.New("duplicate node in the merged path") + // Special case: no edge proof provided; the given range contains all leaves in the trie + if proof == nil { + tr, err := buildTrie(globalTrieHeight, nil, nil, keys, values) + if err != nil { + return false, err } - nodeHashes[*nodeHash] = node - } - if len(mergedPath) == 0 { - return leftPath, rightPath, nil - } + recomputedRoot, err := tr.Root() + if err != nil { + return false, err + } - currNode, err := rootNodeExistsCheck(rootHash, nodeHashes) - if err != nil { - return leftPath, rightPath, err - } + if !recomputedRoot.Equal(root) { + return false, fmt.Errorf("root hash mismatch, expected: %s, got: %s", root.String(), recomputedRoot.String()) + } - if err := pathSplitOccurredCheck(mergedPath, nodeHashes); err != nil { - return leftPath, rightPath, err + return false, nil // no more elements available } - traverseNodes(currNode, &commonPath, nodeHashes) - - leftPath = append(leftPath, commonPath...) - rightPath = append(rightPath, commonPath...) + nodes := NewStorageNodeSet() + firstKey := FeltToKey(globalTrieHeight, first) - currNode = commonPath[len(commonPath)-1] - - leftNode := nodeHashes[*currNode.(*Binary).LeftHash] - rightNode := nodeHashes[*currNode.(*Binary).RightHash] - - traverseNodes(leftNode, &leftPath, nodeHashes) - traverseNodes(rightNode, &rightPath, nodeHashes) + // Special case: there is a provided proof but no key-value pairs, make sure regenerated trie has no more values + // Empty range proof with more elements on the right is not accepted in this function. + // This is due to snap sync specification detail, where the responder must send an existing key (if any) if the requested range is empty. + if len(keys) == 0 { + rootKey, val, err := proofToPath(root, &firstKey, proof, nodes) + if err != nil { + return false, err + } - return leftPath, rightPath, nil -} + if val != nil || hasRightElement(rootKey, &firstKey, nodes) { + return false, errors.New("more entries available") + } -// https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L514 -// GetProof generates a set of proof nodes from the root to the leaf. -// The proof never contains the leaf node if it is set, as we already know it's hash. 
-func GetProof(key *Key, tri *Trie) ([]ProofNode, error) { - nodesFromRoot, err := tri.nodesFromRoot(key) - if err != nil { - return nil, err + return false, nil } - proofNodes := []ProofNode{} - var parentKey *Key + last := keys[len(keys)-1] + lastKey := FeltToKey(globalTrieHeight, last) - for i, sNode := range nodesFromRoot { - sNodeEdge, sNodeBinary, err := transformNode(tri, parentKey, sNode) + // Special case: there is only one element and two edge keys are the same + if len(keys) == 1 && firstKey.Equal(&lastKey) { + rootKey, val, err := proofToPath(root, &firstKey, proof, nodes) if err != nil { - return nil, err + return false, err } - isLeaf := sNode.key.len == tri.height - if sNodeEdge != nil && !isLeaf { // Internal Edge - proofNodes = append(proofNodes, sNodeEdge, sNodeBinary) - } else if sNodeEdge == nil && !isLeaf { // Internal Binary - proofNodes = append(proofNodes, sNodeBinary) - } else if sNodeEdge != nil && isLeaf { // Leaf Edge - proofNodes = append(proofNodes, sNodeEdge) - } else if sNodeEdge == nil && sNodeBinary == nil { // sNode is a binary leaf - break + elementKey := FeltToKey(globalTrieHeight, keys[0]) + if !firstKey.Equal(&elementKey) { + return false, errors.New("correct proof but invalid key") } - parentKey = nodesFromRoot[i].key - } - return proofNodes, nil -} -// VerifyProof checks if `leafPath` leads from `root` to `leafHash` along the `proofNodes` -// https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2006 -func VerifyProof(root *felt.Felt, key *Key, value *felt.Felt, proofs []ProofNode, hash hashFunc) bool { - expectedHash := root - remainingPath := NewKey(key.len, key.bitset[:]) - for i, proofNode := range proofs { - if !proofNode.Hash(hash).Equal(expectedHash) { - return false + if val == nil || !values[0].Equal(val) { + return false, errors.New("correct proof but invalid value") } - switch proofNode := proofNode.(type) { - case *Binary: - if remainingPath.Test(remainingPath.Len() - 1) { - expectedHash = proofNode.RightHash - } else { - expectedHash = proofNode.LeftHash - } - remainingPath.RemoveLastBit() - case *Edge: - subKey, err := remainingPath.SubKey(proofNode.Path.Len()) - if err != nil { - return false - } - - // Todo: - // If we are verifying the key doesn't exist, then we should - // update subKey to point in the other direction - if value == nil && i == len(proofs)-1 { - return true - } - - if !proofNode.Path.Equal(subKey) { - return false - } - expectedHash = proofNode.Child - remainingPath.Truncate(251 - proofNode.Path.Len()) //nolint:mnd - } + return hasRightElement(rootKey, &firstKey, nodes), nil } - return expectedHash.Equal(value) -} - -// VerifyRangeProof verifies the range proof for the given range of keys. -// This is achieved by constructing a trie from the boundary proofs, and the supplied key-values. -// If the root of the reconstructed trie matches the supplied root, then the verification passes. -// If the trie is constructed incorrectly then the root will have an incorrect key(len,path), and value, -// and therefore it's hash won't match the expected root. 
-// ref: https://github.com/ethereum/go-ethereum/blob/v1.14.3/trie/proof.go#L484 -func VerifyRangeProof(root *felt.Felt, keys, values []*felt.Felt, proofKeys [2]*Key, proofValues [2]*felt.Felt, - proofs [2][]ProofNode, hash hashFunc, -) (bool, error) { - // Step 0: checks - if len(keys) != len(values) { - return false, fmt.Errorf("inconsistent proof data, number of keys: %d, number of values: %d", len(keys), len(values)) + // In all other cases, we require two edge paths available. + // First, ensure that the last key is greater than the first key + if last.Cmp(first) <= 0 { + return false, errors.New("last key is less than first key") } - // Ensure all keys are monotonic increasing - if err := ensureMonotonicIncreasing(proofKeys, keys); err != nil { + rootKey, _, err := proofToPath(root, &firstKey, proof, nodes) + if err != nil { return false, err } - // Ensure the inner values contain no deletions - for _, value := range values { - if value.Equal(&felt.Zero) { - return false, errors.New("range contains deletion") - } + lastRootKey, _, err := proofToPath(root, &lastKey, proof, nodes) + if err != nil { + return false, err } - // Step 1: Verify proofs, and get proof paths - var proofPaths [2][]StorageNode - var err error - for i := 0; i < 2; i++ { - if proofs[i] != nil { - if !VerifyProof(root, proofKeys[i], proofValues[i], proofs[i], hash) { - return false, fmt.Errorf("invalid proof for key %x", proofKeys[i].String()) - } - - proofPaths[i], err = ProofToPath(proofs[i], proofKeys[i], hash) - if err != nil { - return false, err - } - } + if !rootKey.Equal(lastRootKey) { + return false, errors.New("first and last root keys do not match") } - // Step 2: Build trie from proofPaths and keys - tmpTrie, err := BuildTrie(proofPaths[0], proofPaths[1], keys, values) + // Build the trie from the proof paths + tr, err := buildTrie(globalTrieHeight, rootKey, nodes.List(), keys, values) if err != nil { return false, err } // Verify that the recomputed root hash matches the provided root hash - recomputedRoot, err := tmpTrie.Root() + recomputedRoot, err := tr.Root() if err != nil { return false, err } + if !recomputedRoot.Equal(root) { - return false, errors.New("root hash mismatch") + return false, fmt.Errorf("root hash mismatch, expected: %s, got: %s", root.String(), recomputedRoot.String()) } - return true, nil + return hasRightElement(rootKey, &lastKey, nodes), nil } -func ensureMonotonicIncreasing(proofKeys [2]*Key, keys []*felt.Felt) error { - if proofKeys[0] != nil { - leftProofFelt := proofKeys[0].Felt() - if leftProofFelt.Cmp(keys[0]) >= 0 { - return errors.New("range is not monotonically increasing") - } - } - if proofKeys[1] != nil { - rightProofFelt := proofKeys[1].Felt() - if keys[len(keys)-1].Cmp(&rightProofFelt) >= 0 { - return errors.New("range is not monotonically increasing") - } - } - if len(keys) >= 2 { - for i := 0; i < len(keys)-1; i++ { - if keys[i].Cmp(keys[i+1]) >= 0 { - return errors.New("range is not monotonically increasing") - } - } +// isEdge checks if the storage node is an edge node. 
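+// A node is an edge node when its key is more than one bit longer than its parent's key,
+// or, for the root, when its key is non-empty.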
+func isEdge(parentKey *Key, sNode StorageNode) bool { + sNodeLen := sNode.key.len + if parentKey == nil { // Root + return sNodeLen != 0 } - return nil + return sNodeLen-parentKey.len > 1 } -// compressNode determines if the node needs compressed, and if so, the len needed to arrive at the next key -func compressNode(idx int, proofNodes []ProofNode, hashF hashFunc) (int, uint8, error) { - parent := proofNodes[idx] - - if idx == len(proofNodes)-1 { - if _, ok := parent.(*Edge); ok { - return 1, parent.Len(), nil +// storageNodeToProofNode converts a StorageNode to the ProofNode(s). +// Juno's Trie has nodes that are Binary AND Edge, whereas the protocol requires nodes that are Binary XOR Edge. +// We need to convert the former to the latter for proof generation. +func storageNodeToProofNode(tri *Trie, parentKey *Key, sNode StorageNode) (*Edge, *Binary, error) { + var edge *Edge + if isEdge(parentKey, sNode) { + edgePath := path(sNode.key, parentKey) + edge = &Edge{ + Path: &edgePath, + Child: sNode.node.Value, } - return 0, parent.Len(), nil + } + if sNode.key.len == tri.height { // Leaf + return edge, nil, nil + } + lNode, err := tri.GetNodeFromKey(sNode.node.Left) + if err != nil { + return nil, nil, err + } + rNode, err := tri.GetNodeFromKey(sNode.node.Right) + if err != nil { + return nil, nil, err } - child := proofNodes[idx+1] - _, isChildBinary := child.(*Binary) - isChildEdge := !isChildBinary - switch parent := parent.(type) { - case *Edge: - if isChildEdge { - break - } - return 1, parent.Len(), nil - case *Binary: - if isChildBinary { - break + rightHash := rNode.Value + if isEdge(sNode.key, StorageNode{node: rNode, key: sNode.node.Right}) { + edgePath := path(sNode.node.Right, sNode.key) + rEdge := &Edge{ + Path: &edgePath, + Child: rNode.Value, } - childHash := child.Hash(hashF) - if parent.LeftHash.Equal(childHash) || parent.RightHash.Equal(childHash) { - return 1, child.Len(), nil + rightHash = rEdge.Hash(tri.hash) + } + leftHash := lNode.Value + if isEdge(sNode.key, StorageNode{node: lNode, key: sNode.node.Left}) { + edgePath := path(sNode.node.Left, sNode.key) + lEdge := &Edge{ + Path: &edgePath, + Child: lNode.Value, } - return 0, 0, ErrChildHashNotFound + leftHash = lEdge.Hash(tri.hash) + } + binary := &Binary{ + LeftHash: leftHash, + RightHash: rightHash, } - return 0, 1, nil + return edge, binary, nil } -func assignChild(i, compressedParent int, parentNode *Node, - nilKey, leafKey, parentKey *Key, proofNodes []ProofNode, hashF hashFunc, -) (*Key, error) { - childInd := i + compressedParent + 1 - childKey, err := getChildKey(childInd, parentKey, leafKey, nilKey, proofNodes, hashF) +// proofToPath converts a Merkle proof to trie node path. All necessary nodes will be resolved and leave the remaining +// as hashes. The given edge proof can be existent or non-existent. +func proofToPath(root *felt.Felt, key *Key, proof *ProofNodeSet, nodes *StorageNodeSet) (*Key, *felt.Felt, error) { + rootKey, val, err := buildPath(root, key, 0, nil, proof, nodes) if err != nil { - return nil, err - } - if leafKey.Test(leafKey.len - parentKey.len - 1) { - parentNode.Right = childKey - parentNode.Left = nilKey - } else { - parentNode.Right = nilKey - parentNode.Left = childKey + return nil, nil, err } - return childKey, nil -} -// ProofToPath returns a set of storage nodes from the root to the end of the proof path. -// The storage nodes will have the hashes of the children, but only the key of the child -// along the path outlined by the proof. 
-func ProofToPath(proofNodes []ProofNode, leafKey *Key, hashF hashFunc) ([]StorageNode, error) { - pathNodes := []StorageNode{} - - // Child keys that can't be derived are set to nilKey, so that we can store the node - zeroFeltBytes := new(felt.Felt).Bytes() - nilKey := NewKey(0, zeroFeltBytes[:]) - - for i, pNode := range proofNodes { - // Keep moving along the path (may need to skip nodes that were compressed into the last path node) - if i != 0 { - if skipNode(pNode, pathNodes, hashF) { - continue - } + // Special case: non-existent key at the root + // We must include the root node in the node set. + // We will only get the following two cases: + // 1. The root node is an edge node only where path.len == key.len (single key trie) + // 2. The root node is an edge node + binary node (double key trie) + if nodes.Size() == 0 { + proofNode, ok := proof.Get(*root) + if !ok { + return nil, nil, fmt.Errorf("root proof node not found: %s", root) } - var parentKey *Key - parentNode := Node{} - - // Set the key of the current node - compressParent, compressParentOffset, err := compressNode(i, proofNodes, hashF) - if err != nil { - return nil, err - } - parentKey, err = getParentKey(i, compressParentOffset, leafKey, pNode, pathNodes, proofNodes) - if err != nil { - return nil, err + edge, ok := proofNode.(*Edge) + if !ok { + return nil, nil, fmt.Errorf("expected edge node at root, got: %T", proofNode) } - // Don't store leafs along proof paths - if parentKey.len == 251 { //nolint:mnd - break - } + sn := NewPartialStorageNode(edge.Path, edge.Child) - // Set the value of the current node - parentNode.Value = pNode.Hash(hashF) + // Handle leaf edge case (single key trie) + if edge.Path.Len() == key.Len() { + if err := nodes.Put(*sn.key, sn); err != nil { + return nil, nil, fmt.Errorf("failed to store leaf edge: %w", err) + } + return sn.Key(), sn.Value(), nil + } - // Set the child key of the current node. 
- childKey, err := assignChild(i, compressParent, &parentNode, &nilKey, leafKey, parentKey, proofNodes, hashF) - if err != nil { - return nil, err + // Handle edge + binary case (double key trie) + child, ok := proof.Get(*edge.Child) + if !ok { + return nil, nil, fmt.Errorf("edge child not found: %s", edge.Child) } - // Set the LeftHash and RightHash values - parentNode.LeftHash, parentNode.RightHash, err = getLeftRightHash(i, proofNodes) - if err != nil { - return nil, err + binary, ok := child.(*Binary) + if !ok { + return nil, nil, fmt.Errorf("expected binary node as child, got: %T", child) } - pathNodes = append(pathNodes, StorageNode{key: parentKey, node: &parentNode}) + sn.node.LeftHash = binary.LeftHash + sn.node.RightHash = binary.RightHash - // break early since we don't store leafs along proof paths, or if no more nodes exist along the proof paths - if childKey.len == 0 || childKey.len == 251 { - break + if err := nodes.Put(*sn.key, sn); err != nil { + return nil, nil, fmt.Errorf("failed to store edge+binary: %w", err) } + rootKey = sn.Key() } - return pathNodes, nil + return rootKey, val, nil } -func skipNode(pNode ProofNode, pathNodes []StorageNode, hashF hashFunc) bool { - lastNode := pathNodes[len(pathNodes)-1].node - noLeftMatch, noRightMatch := false, false - if lastNode.LeftHash != nil && !pNode.Hash(hashF).Equal(lastNode.LeftHash) { - noLeftMatch = true - } - if lastNode.RightHash != nil && !pNode.Hash(hashF).Equal(lastNode.RightHash) { - noRightMatch = true - } - if noLeftMatch && noRightMatch { - return true +// buildPath recursively builds the path for a given node hash, key, and current position. +// It returns the current node's key and any leaf value found along this path. +func buildPath( + nodeHash *felt.Felt, + key *Key, + curPos uint8, + curNode *StorageNode, + proof *ProofNodeSet, + nodes *StorageNodeSet, +) (*Key, *felt.Felt, error) { + // We reached the leaf + if curPos == key.Len() { + leafKey := key.Copy() + leafNode := NewPartialStorageNode(&leafKey, nodeHash) + if err := nodes.Put(leafKey, leafNode); err != nil { + return nil, nil, err + } + return leafNode.Key(), leafNode.Value(), nil } - return false -} -func getLeftRightHash(parentInd int, proofNodes []ProofNode) (*felt.Felt, *felt.Felt, error) { - parent := proofNodes[parentInd] + proofNode, ok := proof.Get(*nodeHash) + if !ok { // non-existent proof node + return NilKey, nil, nil + } - switch parent := parent.(type) { + switch pn := proofNode.(type) { case *Binary: - return parent.LeftHash, parent.RightHash, nil + return handleBinaryNode(pn, nodeHash, key, curPos, curNode, proof, nodes) case *Edge: - if parentInd+1 > len(proofNodes)-1 { - return nil, nil, errors.New("cant get hash of children from proof node, out of range") - } - parentBinary := proofNodes[parentInd+1].(*Binary) - return parentBinary.LeftHash, parentBinary.RightHash, nil - default: - return nil, nil, fmt.Errorf("%w: %T", ErrUnknownProofNode, parent) + return handleEdgeNode(pn, key, curPos, proof, nodes) } -} -func getParentKey(idx int, compressedParentOffset uint8, leafKey *Key, - pNode ProofNode, pathNodes []StorageNode, proofNodes []ProofNode, -) (*Key, error) { - var crntKey *Key - var err error + return nil, nil, nil +} - var height uint8 - if len(pathNodes) > 0 { - if p, ok := proofNodes[idx].(*Edge); ok { - height = pathNodes[len(pathNodes)-1].key.len + p.Path.len - } else { - height = pathNodes[len(pathNodes)-1].key.len + 1 +// handleBinaryNode processes a binary node in the proof path by creating/updating a storage node, +// 
setting its left/right hashes, and recursively building the path for the appropriate child direction. +// It returns the current node's key and any leaf value found along this path. +func handleBinaryNode( + binary *Binary, + nodeHash *felt.Felt, + key *Key, + curPos uint8, + curNode *StorageNode, + proof *ProofNodeSet, + nodes *StorageNodeSet, +) (*Key, *felt.Felt, error) { + // If curNode is nil, it means that this current binary node is the root node. + // Or, it's an internal binary node and the parent is also a binary node. + // A standalone binary proof node always corresponds to a single storage node. + // If curNode is not nil, it means that the parent node is an edge node. + // In this case, the key of the storage node is based on the parent edge node. + if curNode == nil { + nodeKey, err := key.MostSignificantBits(curPos) + if err != nil { + return nil, nil, err } + curNode = NewPartialStorageNode(nodeKey, nodeHash) } + curNode.node.LeftHash = binary.LeftHash + curNode.node.RightHash = binary.RightHash - if _, ok := pNode.(*Binary); ok { - crntKey, err = leafKey.SubKey(height) + // Calculate next position and determine to take left or right path + nextPos := curPos + 1 + isRightPath := key.IsBitSet(key.Len() - nextPos) + nextHash := binary.LeftHash + if isRightPath { + nextHash = binary.RightHash + } + + childKey, val, err := buildPath(nextHash, key, nextPos, nil, proof, nodes) + if err != nil { + return nil, nil, err + } + + // Set child reference + if isRightPath { + curNode.node.Right = childKey } else { - crntKey, err = leafKey.SubKey(height + compressedParentOffset) + curNode.node.Left = childKey + } + + if err := nodes.Put(*curNode.key, curNode); err != nil { + return nil, nil, fmt.Errorf("failed to store binary node: %w", err) } - return crntKey, err + + return curNode.Key(), val, nil } -func getChildKey(childIdx int, crntKey, leafKey, nilKey *Key, proofNodes []ProofNode, hashF hashFunc) (*Key, error) { - if childIdx > len(proofNodes)-1 { - return nilKey, nil +// handleEdgeNode processes an edge node in the proof path by verifying the edge path matches +// the key path and either creating a leaf node or continuing to traverse the trie. It returns +// the current node's key and any leaf value found along this path. 
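+// If the edge path does not match the corresponding key bits, NilKey and a nil value are
+// returned, which proves that the key is not present in the trie.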
+func handleEdgeNode( + edge *Edge, + key *Key, + curPos uint8, + proof *ProofNodeSet, + nodes *StorageNodeSet, +) (*Key, *felt.Felt, error) { + // Verify the edge path matches the key path + if !verifyEdgePath(key, edge.Path, curPos) { + return NilKey, nil, nil + } + + // The next node position is the end of the edge path + nextPos := curPos + edge.Path.Len() + nodeKey, err := key.MostSignificantBits(nextPos) + if err != nil { + return nil, nil, fmt.Errorf("failed to get MSB for internal edge: %w", err) } + curNode := NewPartialStorageNode(nodeKey, edge.Child) - compressChild, compressChildOffset, err := compressNode(childIdx, proofNodes, hashF) + // This is an edge leaf, stop traversing the trie + if nextPos == key.Len() { + if err := nodes.Put(*curNode.key, curNode); err != nil { + return nil, nil, fmt.Errorf("failed to store edge leaf: %w", err) + } + return curNode.Key(), curNode.Value(), nil + } + + _, val, err := buildPath(edge.Child, key, nextPos, curNode, proof, nodes) if err != nil { - return nil, err + return nil, nil, fmt.Errorf("failed to build child path: %w", err) } - if crntKey.len+uint8(compressChild)+compressChildOffset == 251 { //nolint:mnd - return nilKey, nil + if err := nodes.Put(*curNode.key, curNode); err != nil { + return nil, nil, fmt.Errorf("failed to store internal edge: %w", err) } - return leafKey.SubKey(crntKey.len + uint8(compressChild) + compressChildOffset) + return curNode.Key(), val, nil } -// BuildTrie builds a trie using the proof paths (including inner nodes), and then sets all the keys-values (leaves) -func BuildTrie(leftProofPath, rightProofPath []StorageNode, keys, values []*felt.Felt) (*Trie, error) { //nolint:gocyclo - tempTrie, err := NewTriePedersen(newMemStorage(), 251) //nolint:mnd +// verifyEdgePath checks if the edge path matches the key path at the current position. +func verifyEdgePath(key, edgePath *Key, curPos uint8) bool { + if key.Len() < curPos+edgePath.Len() { + return false + } + + // Ensure the bits between segment of the key and the node path match + start := key.Len() - curPos - edgePath.Len() + end := key.Len() - curPos + for i := start; i < end; i++ { + if key.IsBitSet(i) != edgePath.IsBitSet(i-start) { + return false // paths diverge - this proves non-membership + } + } + return true +} + +// buildTrie builds a trie from a list of storage nodes and a list of keys and values. +func buildTrie(height uint8, rootKey *Key, nodes []*StorageNode, keys, values []*felt.Felt) (*Trie, error) { + tr, err := NewTriePedersen(newMemStorage(), height) if err != nil { return nil, err } - // merge proof paths - for i := range min(len(leftProofPath), len(rightProofPath)) { - // Can't store nil keys so stop merging - if leftProofPath[i].node.Left == nil || leftProofPath[i].node.Right == nil || - rightProofPath[i].node.Left == nil || rightProofPath[i].node.Right == nil { - break - } - if leftProofPath[i].key.Equal(rightProofPath[i].key) { - leftProofPath[i].node.Right = rightProofPath[i].node.Right - rightProofPath[i].node.Left = leftProofPath[i].node.Left - } else { - break + tr.setRootKey(rootKey) + + // Nodes are inserted in reverse order because the leaf nodes are placed at the front of the list. + // We would want to insert root node first so the root key is set first. 
+ for i := len(nodes) - 1; i >= 0; i-- { + if err := tr.PutInner(nodes[i].key, nodes[i].node); err != nil { + return nil, err } } - for _, sNode := range leftProofPath { - if sNode.node.Left == nil || sNode.node.Right == nil { - break - } - _, err := tempTrie.PutInner(sNode.key, sNode.node) + for index, key := range keys { + _, err = tr.PutWithProof(key, values[index], nodes) if err != nil { return nil, err } } - for _, sNode := range rightProofPath { - if sNode.node.Left == nil || sNode.node.Right == nil { - break + return tr, nil +} + +// hasRightElement checks if there is a right sibling for the given key in the trie. +// This function assumes that the entire path has been resolved. +func hasRightElement(rootKey, key *Key, nodes *StorageNodeSet) bool { + cur := rootKey + for cur != nil && !cur.Equal(NilKey) { + sn, ok := nodes.Get(*cur) + if !ok { + return false } - _, err := tempTrie.PutInner(sNode.key, sNode.node) - if err != nil { - return nil, err + + // We resolved the entire path, no more elements + if key.Equal(cur) { + return false } - } - for i := range len(keys) { - _, err := tempTrie.PutWithProof(keys[i], values[i], leftProofPath, rightProofPath) - if err != nil { - return nil, err + // If we're taking a left path and there's a right sibling, + // then there are elements with larger values + bitPos := key.Len() - cur.Len() - 1 + isLeft := !key.IsBitSet(bitPos) + if isLeft && sn.node.RightHash != nil { + return true + } + + // Move to next node based on the path + cur = sn.node.Right + if isLeft { + cur = sn.node.Left } } - return tempTrie, nil + + return false } diff --git a/core/trie/proof_test.go b/core/trie/proof_test.go index e6d8576d83..94eaabc549 100644 --- a/core/trie/proof_test.go +++ b/core/trie/proof_test.go @@ -1,6 +1,8 @@ package trie_test import ( + "math/rand" + "sort" "testing" "github.com/NethermindEth/juno/core/crypto" @@ -8,1132 +10,818 @@ import ( "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db/pebble" "github.com/NethermindEth/juno/utils" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func buildSimpleTrie(t *testing.T) *trie.Trie { - // (250, 0, x1) edge - // | - // (0,0,x1) binary - // / \ - // (2) (3) - // Build trie - memdb := pebble.NewMemTest(t) - txn, err := memdb.NewTransaction(true) - require.NoError(t, err) - - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) - require.NoError(t, err) +func TestProve(t *testing.T) { + t.Parallel() - // Update trie - key1 := new(felt.Felt).SetUint64(0) - key2 := new(felt.Felt).SetUint64(1) - value1 := new(felt.Felt).SetUint64(2) - value2 := new(felt.Felt).SetUint64(3) + n := 1000 + tempTrie, records := nonRandomTrie(t, n) - _, err = tempTrie.Put(key1, value1) - require.NoError(t, err) + for _, record := range records { + proofSet := trie.NewProofNodeSet() + err := tempTrie.Prove(record.key, proofSet) + require.NoError(t, err) - _, err = tempTrie.Put(key2, value2) - require.NoError(t, err) + root, err := tempTrie.Root() + require.NoError(t, err) - require.NoError(t, tempTrie.Commit()) - return tempTrie + val, err := trie.VerifyProof(root, record.key, proofSet, crypto.Pedersen) + if err != nil { + t.Fatalf("failed for key %s", record.key.String()) + } + require.Equal(t, record.value, val) + } } -func buildSimpleBinaryRootTrie(t *testing.T) *trie.Trie { - // PF - // (0, 0, x) - // / \ - // (250, 0, cc) (250, 11111.., dd) - // | | - // (cc) (dd) - - // JUNO - // (0, 0, x) - // / \ - // (251, 0, cc) (251, 11111.., dd) - - memdb := 
pebble.NewMemTest(t) - txn, err := memdb.NewTransaction(true) - require.NoError(t, err) +func TestProveNonExistent(t *testing.T) { + t.Parallel() - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{1}), 251) - require.NoError(t, err) + n := 1000 + tempTrie, _ := nonRandomTrie(t, n) - key1 := new(felt.Felt).SetUint64(0) - key2 := utils.HexToFelt(t, "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") - value1 := utils.HexToFelt(t, "0xcc") - value2 := utils.HexToFelt(t, "0xdd") + for i := 1; i < n+1; i++ { + keyFelt := new(felt.Felt).SetUint64(uint64(i + n)) - _, err = tempTrie.Put(key1, value1) - require.NoError(t, err) + proofSet := trie.NewProofNodeSet() + err := tempTrie.Prove(keyFelt, proofSet) + require.NoError(t, err) - _, err = tempTrie.Put(key2, value2) - require.NoError(t, err) + root, err := tempTrie.Root() + require.NoError(t, err) - require.NoError(t, tempTrie.Commit()) - return tempTrie + val, err := trie.VerifyProof(root, keyFelt, proofSet, crypto.Pedersen) + if err != nil { + t.Fatalf("failed for key %s", keyFelt.String()) + } + require.Equal(t, &felt.Zero, val) + } } -func buildSimpleDoubleBinaryTrie(t *testing.T) (*trie.Trie, []trie.ProofNode) { - // (249,0,x3) // Edge - // | - // (0, 0, x3) // Binary - // / \ - // (0,0,x1) // B (1, 1, 5) // Edge leaf - // / \ | - // (2) (3) (5) - - // Build trie - memdb := pebble.NewMemTest(t) - txn, err := memdb.NewTransaction(true) - require.NoError(t, err) - - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{1}), 251) - require.NoError(t, err) - - // Update trie - key1 := new(felt.Felt).SetUint64(0) - key2 := new(felt.Felt).SetUint64(1) - key3 := new(felt.Felt).SetUint64(3) - value1 := new(felt.Felt).SetUint64(2) - value2 := new(felt.Felt).SetUint64(3) - value3 := new(felt.Felt).SetUint64(5) - - _, err = tempTrie.Put(key1, value1) - require.NoError(t, err) +func TestProveRandom(t *testing.T) { + t.Parallel() + tempTrie, records := randomTrie(t, 1000) - _, err = tempTrie.Put(key2, value2) - require.NoError(t, err) + for _, record := range records { + proofSet := trie.NewProofNodeSet() + err := tempTrie.Prove(record.key, proofSet) + require.NoError(t, err) - _, err = tempTrie.Put(key3, value3) - require.NoError(t, err) + root, err := tempTrie.Root() + require.NoError(t, err) - require.NoError(t, tempTrie.Commit()) + val, err := trie.VerifyProof(root, record.key, proofSet, crypto.Pedersen) + require.NoError(t, err) + require.Equal(t, record.value, val) + } +} - zero := trie.NewKey(249, []byte{0}) - key3Bytes := new(felt.Felt).SetUint64(1).Bytes() - path3 := trie.NewKey(1, key3Bytes[:]) - expectedProofNodes := []trie.ProofNode{ - &trie.Edge{ - Path: &zero, - Child: utils.HexToFelt(t, "0x055C81F6A791FD06FC2E2CCAD922397EC76C3E35F2E06C0C0D43D551005A8DEA"), +func TestProveCustom(t *testing.T) { + t.Parallel() + + tests := []testTrie{ + { + name: "simple binary", + buildFn: buildSimpleTrie, + testKeys: []testKey{ + { + name: "prove existing key", + key: new(felt.Felt).SetUint64(1), + expected: new(felt.Felt).SetUint64(3), + }, + }, }, - &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x05774FA77B3D843AE9167ABD61CF80365A9B2B02218FC2F628494B5BDC9B33B8"), - RightHash: utils.HexToFelt(t, "0x07C5BC1CC68B7BC8CA2F632DE98297E6DA9594FA23EDE872DD2ABEAFDE353B43"), + { + name: "simple double binary", + buildFn: buildSimpleDoubleBinaryTrie, + testKeys: []testKey{ + { + name: "prove existing key 0", + key: new(felt.Felt).SetUint64(0), + expected: new(felt.Felt).SetUint64(2), + }, + { + name: "prove existing key 
3", + key: new(felt.Felt).SetUint64(3), + expected: new(felt.Felt).SetUint64(5), + }, + { + name: "prove non-existent key 2", + key: new(felt.Felt).SetUint64(2), + expected: new(felt.Felt).SetUint64(0), + }, + { + name: "prove non-existent key 123", + key: new(felt.Felt).SetUint64(123), + expected: new(felt.Felt).SetUint64(0), + }, + }, }, - &trie.Edge{ - Path: &path3, - Child: value3, + { + name: "simple binary root", + buildFn: buildSimpleBinaryRootTrie, + testKeys: []testKey{ + { + name: "prove existing key", + key: new(felt.Felt).SetUint64(0), + expected: utils.HexToFelt(t, "0xcc"), + }, + }, + }, + { + name: "left-right edge", + buildFn: func(t *testing.T) (*trie.Trie, []*keyValue) { + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) + + tr, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{1}), 251) + require.NoError(t, err) + + records := []*keyValue{ + {key: utils.HexToFelt(t, "0xff"), value: utils.HexToFelt(t, "0xaa")}, + } + + for _, record := range records { + _, err = tr.Put(record.key, record.value) + require.NoError(t, err) + } + require.NoError(t, tr.Commit()) + return tr, records + }, + testKeys: []testKey{ + { + name: "prove existing key", + key: utils.HexToFelt(t, "0xff"), + expected: utils.HexToFelt(t, "0xaa"), + }, + }, + }, + { + name: "three key trie", + buildFn: build3KeyTrie, + testKeys: []testKey{ + { + name: "prove existing key", + key: new(felt.Felt).SetUint64(2), + expected: new(felt.Felt).SetUint64(6), + }, + }, }, } - return tempTrie, expectedProofNodes -} - -func build3KeyTrie(t *testing.T) *trie.Trie { - // Starknet - // -------- - // - // Edge - // | - // Binary with len 249 parent - // / \ - // Binary (250) Edge with len 250 - // / \ / - // 0x4 0x5 0x6 child - - // Juno - // ---- - // - // Node (path 249) - // / \ - // Node (binary) \ - // / \ / - // 0x4 0x5 0x6 - - // Build trie - memdb := pebble.NewMemTest(t) - txn, err := memdb.NewTransaction(true) - require.NoError(t, err) - - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) - require.NoError(t, err) - - // Update trie - key1 := new(felt.Felt).SetUint64(0) - key2 := new(felt.Felt).SetUint64(1) - key3 := new(felt.Felt).SetUint64(2) - value1 := new(felt.Felt).SetUint64(4) - value2 := new(felt.Felt).SetUint64(5) - value3 := new(felt.Felt).SetUint64(6) - - _, err = tempTrie.Put(key1, value1) - require.NoError(t, err) - - _, err = tempTrie.Put(key3, value3) - require.NoError(t, err) - _, err = tempTrie.Put(key2, value2) - require.NoError(t, err) - - require.NoError(t, tempTrie.Commit()) - return tempTrie -} + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() -func build4KeyTrie(t *testing.T) *trie.Trie { - // Juno - // 248 - // / \ - // 249 \ - // / \ \ - // 250 \ \ - // / \ /\ /\ - // 0 1 2 4 - - // Juno - should be able to reconstruct this from proofs - // 248 - // / \ - // 249 // Note we cant derive the right key, but need to store it's hash - // / \ - // 250 \ - // / \ / (Left hash set, no key) - // 0 - - // Pathfinder (???) - // 0 Edge - // | - // 248 Binary - // / \ - // 249 \ Binary Edge ?? - // / \ \ - // 250 250 250 Binary Edge ?? 
- // / \ / / - // 0 1 2 4 - - // Build trie - memdb := pebble.NewMemTest(t) - txn, err := memdb.NewTransaction(true) - require.NoError(t, err) + tr, _ := test.buildFn(t) - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) - require.NoError(t, err) + for _, tc := range test.testKeys { + t.Run(tc.name, func(t *testing.T) { + proofSet := trie.NewProofNodeSet() + err := tr.Prove(tc.key, proofSet) + require.NoError(t, err) - // Update trie - key1 := new(felt.Felt).SetUint64(0) - key2 := new(felt.Felt).SetUint64(1) - key3 := new(felt.Felt).SetUint64(2) - key5 := new(felt.Felt).SetUint64(4) - value1 := new(felt.Felt).SetUint64(4) - value2 := new(felt.Felt).SetUint64(5) - value3 := new(felt.Felt).SetUint64(6) - value5 := new(felt.Felt).SetUint64(7) - - _, err = tempTrie.Put(key1, value1) - require.NoError(t, err) - - _, err = tempTrie.Put(key3, value3) - require.NoError(t, err) - _, err = tempTrie.Put(key2, value2) - require.NoError(t, err) - _, err = tempTrie.Put(key5, value5) - require.NoError(t, err) - - require.NoError(t, tempTrie.Commit()) - - return tempTrie -} - -func noDuplicates(proofNodes []trie.ProofNode) bool { - seen := make(map[felt.Felt]bool) - for _, pNode := range proofNodes { - if _, ok := seen[*pNode.Hash(crypto.Pedersen)]; ok { - return false - } - seen[*pNode.Hash(crypto.Pedersen)] = true - } - return true -} + root, err := tr.Root() + require.NoError(t, err) -// containsAll checks that subsetProofNodes is a subset of proofNodes -func containsAll(proofNodes, subsetProofNodes []trie.ProofNode) bool { - for _, pNode := range subsetProofNodes { - found := false - for _, p := range proofNodes { - if p.Hash(crypto.Pedersen).Equal(pNode.Hash(crypto.Pedersen)) { - found = true - break + val, err := trie.VerifyProof(root, tc.key, proofSet, crypto.Pedersen) + require.NoError(t, err) + require.Equal(t, tc.expected, val) + }) } - } - if !found { - return false - } - } - return true -} - -func isSameProofPath(proofNodes, expectedProofNodes []trie.ProofNode) bool { - if len(proofNodes) != len(expectedProofNodes) { - return false - } - for i := range proofNodes { - if !proofNodes[i].Hash(crypto.Pedersen).Equal(expectedProofNodes[i].Hash(crypto.Pedersen)) { - return false - } + }) } - return true } -func newBinaryProofNode() *trie.Binary { - return &trie.Binary{ - LeftHash: new(felt.Felt).SetUint64(1), - RightHash: new(felt.Felt).SetUint64(2), - } -} +// TestRangeProof tests normal range proof with both edge proofs +func TestRangeProof(t *testing.T) { + t.Parallel() -func TestGetProof(t *testing.T) { - t.Run("GP Simple Trie - simple binary", func(t *testing.T) { - tempTrie := buildSimpleTrie(t) + n := 500 + tr, records := randomTrie(t, n) + root, err := tr.Root() + require.NoError(t, err) - zero := trie.NewKey(250, []byte{0}) - expectedProofNodes := []trie.ProofNode{ - &trie.Edge{ - Path: &zero, - Child: utils.HexToFelt(t, "0x05774FA77B3D843AE9167ABD61CF80365A9B2B02218FC2F628494B5BDC9B33B8"), - }, + for i := 0; i < 100; i++ { + start := rand.Intn(n) + end := rand.Intn(n-start) + start + 1 - &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000002"), - RightHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000003"), - }, - } - leafFelt := new(felt.Felt).SetUint64(0).Bytes() - leafKey := trie.NewKey(251, leafFelt[:]) - proofNodes, err := trie.GetProof(&leafKey, tempTrie) + proof := trie.NewProofNodeSet() + err := tr.GetRangeProof(records[start].key, records[end-1].key, proof) 
require.NoError(t, err) - // Better inspection - // for _, pNode := range proofNodes { - // pNode.PrettyPrint() - // } - require.Equal(t, expectedProofNodes, proofNodes) - }) - - t.Run("GP Simple Trie - simple double binary", func(t *testing.T) { - tempTrie, expectedProofNodes := buildSimpleDoubleBinaryTrie(t) - - expectedProofNodes[2] = &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000002"), - RightHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000003"), + keys := []*felt.Felt{} + values := []*felt.Felt{} + for i := start; i < end; i++ { + keys = append(keys, records[i].key) + values = append(values, records[i].value) } - leafFelt := new(felt.Felt).SetUint64(0).Bytes() - leafKey := trie.NewKey(251, leafFelt[:]) - proofNodes, err := trie.GetProof(&leafKey, tempTrie) + _, err = trie.VerifyRangeProof(root, records[start].key, keys, values, proof) require.NoError(t, err) + } +} - // Better inspection - // for _, pNode := range proofNodes { - // pNode.PrettyPrint() - // } - require.Equal(t, expectedProofNodes, proofNodes) - }) - - t.Run("GP Simple Trie - simple double binary edge", func(t *testing.T) { - tempTrie, expectedProofNodes := buildSimpleDoubleBinaryTrie(t) - leafFelt := new(felt.Felt).SetUint64(3).Bytes() - leafKey := trie.NewKey(251, leafFelt[:]) - proofNodes, err := trie.GetProof(&leafKey, tempTrie) - require.NoError(t, err) +// TestRangeProofWithNonExistentProof tests normal range proof with non-existent proofs +func TestRangeProofWithNonExistentProof(t *testing.T) { + t.Parallel() - // Better inspection - // for _, pNode := range proofNodes { - // pNode.PrettyPrint() - // } - require.Equal(t, expectedProofNodes, proofNodes) - }) + n := 500 + tr, records := randomTrie(t, n) + root, err := tr.Root() + require.NoError(t, err) - t.Run("GP Simple Trie - simple binary root", func(t *testing.T) { - tempTrie := buildSimpleBinaryRootTrie(t) + for i := 0; i < 100; i++ { + start := rand.Intn(n) + end := rand.Intn(n-start) + start + 1 - key1Bytes := new(felt.Felt).SetUint64(0).Bytes() - path1 := trie.NewKey(250, key1Bytes[:]) - expectedProofNodes := []trie.ProofNode{ - &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x06E08BF82793229338CE60B65D1845F836C8E2FBFE2BC59FF24AEDBD8BA219C4"), - RightHash: utils.HexToFelt(t, "0x04F9B8E66212FB528C0C1BD02F43309C53B895AA7D9DC91180001BDD28A588FA"), - }, - &trie.Edge{ - Path: &path1, - Child: utils.HexToFelt(t, "0xcc"), - }, + first := decrementFelt(records[start].key) + if start != 0 && first.Equal(records[start-1].key) { + continue } - leafFelt := new(felt.Felt).SetUint64(0).Bytes() - leafKey := trie.NewKey(251, leafFelt[:]) - - proofNodes, err := trie.GetProof(&leafKey, tempTrie) - require.NoError(t, err) - // Better inspection - // for _, pNode := range proofNodes { - // pNode.PrettyPrint() - // } - require.Equal(t, expectedProofNodes, proofNodes) - }) - - t.Run("GP Simple Trie - left-right edge", func(t *testing.T) { - // (251,0xff,0xaa) - // / - // \ - // (0xaa) - memdb := pebble.NewMemTest(t) - txn, err := memdb.NewTransaction(true) - require.NoError(t, err) - - tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{1}), 251) - require.NoError(t, err) - - key1 := utils.HexToFelt(t, "0xff") - value1 := utils.HexToFelt(t, "0xaa") - - _, err = tempTrie.Put(key1, value1) + proof := trie.NewProofNodeSet() + err := tr.GetRangeProof(first, records[end-1].key, proof) require.NoError(t, err) - require.NoError(t, tempTrie.Commit()) - - key1Bytes := 
key1.Bytes() - path1 := trie.NewKey(251, key1Bytes[:]) - - child := utils.HexToFelt(t, "0x00000000000000000000000000000000000000000000000000000000000000AA") - expectedProofNodes := []trie.ProofNode{ - &trie.Edge{ - Path: &path1, - Child: child, - }, + keys := make([]*felt.Felt, end-start) + values := make([]*felt.Felt, end-start) + for i := start; i < end; i++ { + keys[i-start] = records[i].key + values[i-start] = records[i].value } - leafFelt := new(felt.Felt).SetUint64(0).Bytes() - leafKey := trie.NewKey(251, leafFelt[:]) - proofNodes, err := trie.GetProof(&leafKey, tempTrie) - require.NoError(t, err) - - // Better inspection - // for _, pNode := range proofNodes { - // pNode.PrettyPrint() - // } - require.Equal(t, expectedProofNodes, proofNodes) - }) - - t.Run("GP Simple Trie - proof for non-set key", func(t *testing.T) { - tempTrie, expectedProofNodes := buildSimpleDoubleBinaryTrie(t) - leafFelt := new(felt.Felt).SetUint64(123).Bytes() // The (root) edge node would have a shorter len if this key was set - leafKey := trie.NewKey(251, leafFelt[:]) - proofNodes, err := trie.GetProof(&leafKey, tempTrie) + _, err = trie.VerifyRangeProof(root, first, keys, values, proof) require.NoError(t, err) + } +} - // Better inspection - for _, pNode := range proofNodes { - pNode.PrettyPrint() - } - require.Equal(t, expectedProofNodes[0:2], proofNodes) - }) +// TestRangeProofWithInvalidNonExistentProof tests range proof with invalid non-existent proofs. +// One scenario is when there is a gap between the first element and the left edge proof. +func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { + t.Parallel() - t.Run("GP Simple Trie - proof for inner key", func(t *testing.T) { - tempTrie, expectedProofNodes := buildSimpleDoubleBinaryTrie(t) + n := 500 + tr, records := randomTrie(t, n) + root, err := tr.Root() + require.NoError(t, err) - innerFelt := new(felt.Felt).SetUint64(2).Bytes() - innerKey := trie.NewKey(123, innerFelt[:]) // The (root) edge node has len 249 which shows this doesn't exist - proofNodes, err := trie.GetProof(&innerKey, tempTrie) - require.NoError(t, err) + start, end := 100, 200 + first := decrementFelt(records[start].key) - // Better inspection - for _, pNode := range proofNodes { - pNode.PrettyPrint() - } - require.Equal(t, expectedProofNodes[0:2], proofNodes) - }) - - t.Run("GP Simple Trie - proof for non-set key, with leafs set to right and left", func(t *testing.T) { - tempTrie, expectedProofNodes := buildSimpleDoubleBinaryTrie(t) + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(first, records[end-1].key, proof) + require.NoError(t, err) - leafFelt := new(felt.Felt).SetUint64(2).Bytes() - leafKey := trie.NewKey(251, leafFelt[:]) - proofNodes, err := trie.GetProof(&leafKey, tempTrie) - require.NoError(t, err) + start = 105 // Gap created + keys := make([]*felt.Felt, end-start) + values := make([]*felt.Felt, end-start) + for i := start; i < end; i++ { + keys[i-start] = records[i].key + values[i-start] = records[i].value + } - // Better inspection - for _, pNode := range proofNodes { - pNode.PrettyPrint() - } - require.Equal(t, expectedProofNodes, proofNodes) - }) + _, err = trie.VerifyRangeProof(root, first, keys, values, proof) + require.Error(t, err) } -func TestVerifyProof(t *testing.T) { - // https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2137 - t.Run("VP Simple binary trie", func(t *testing.T) { - tempTrie := buildSimpleTrie(t) - zero := trie.NewKey(250, []byte{0}) - expectedProofNodes := []trie.ProofNode{ - &trie.Edge{ - 
Path: &zero, - Child: utils.HexToFelt(t, "0x05774FA77B3D843AE9167ABD61CF80365A9B2B02218FC2F628494B5BDC9B33B8"), - }, - &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000002"), - RightHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000003"), - }, - } - - root, err := tempTrie.Root() - require.NoError(t, err) - val1 := new(felt.Felt).SetUint64(2) - - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - leafkey := trie.NewKey(251, zeroFeltBytes[:]) - assert.True(t, trie.VerifyProof(root, &leafkey, val1, expectedProofNodes, crypto.Pedersen)) - }) - - // https://github.com/eqlabs/pathfinder/blob/main/crates/merkle-tree/src/tree.rs#L2167 - t.Run("VP Simple double binary trie", func(t *testing.T) { - tempTrie, _ := buildSimpleDoubleBinaryTrie(t) - zero := trie.NewKey(249, []byte{0}) - expectedProofNodes := []trie.ProofNode{ - &trie.Edge{ - Path: &zero, - Child: utils.HexToFelt(t, "0x055C81F6A791FD06FC2E2CCAD922397EC76C3E35F2E06C0C0D43D551005A8DEA"), - }, - &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x05774FA77B3D843AE9167ABD61CF80365A9B2B02218FC2F628494B5BDC9B33B8"), - RightHash: utils.HexToFelt(t, "0x07C5BC1CC68B7BC8CA2F632DE98297E6DA9594FA23EDE872DD2ABEAFDE353B43"), - }, - &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000002"), - RightHash: utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000003"), - }, - } +func TestOneElementRangeProof(t *testing.T) { + t.Parallel() - root, err := tempTrie.Root() - require.NoError(t, err) - val1 := new(felt.Felt).SetUint64(2) - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - leafkey := trie.NewKey(251, zeroFeltBytes[:]) - assert.True(t, trie.VerifyProof(root, &leafkey, val1, expectedProofNodes, crypto.Pedersen)) - }) + n := 1000 + tr, records := randomTrie(t, n) + root, err := tr.Root() + require.NoError(t, err) - t.Run("VP three key trie", func(t *testing.T) { - tempTrie := build3KeyTrie(t) - zero := trie.NewKey(249, []byte{0}) - felt2 := new(felt.Felt).SetUint64(0).Bytes() - lastPath := trie.NewKey(1, felt2[:]) - expectedProofNodes := []trie.ProofNode{ - &trie.Edge{ - Path: &zero, - Child: utils.HexToFelt(t, "0x0768DEB8D0795D80AAAC2E5E326141F33044759F97A1BF092D8EB9C4E4BE9234"), - }, - &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x057166F9476D0A2D6875124251841EB85A9AE37462FAE3CBF7304BCD593938E7"), - RightHash: utils.HexToFelt(t, "0x060FBDE29F96F706498EFD132DC7F312A4C99A9AE051BF152C2AF2B3CAF31E5B"), - }, - &trie.Edge{ - Path: &lastPath, - Child: utils.HexToFelt(t, "0x6"), - }, - } + t.Run("both edge proofs with the same key", func(t *testing.T) { + t.Parallel() - root, err := tempTrie.Root() + start := 100 + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(records[start].key, records[start].key, proof) require.NoError(t, err) - val6 := new(felt.Felt).SetUint64(6) - twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() - leafkey := trie.NewKey(251, twoFeltBytes[:]) - gotProof, err := trie.GetProof(&leafkey, tempTrie) + _, err = trie.VerifyRangeProof(root, records[start].key, []*felt.Felt{records[start].key}, []*felt.Felt{records[start].value}, proof) require.NoError(t, err) - require.Equal(t, expectedProofNodes, gotProof) - - assert.True(t, trie.VerifyProof(root, &leafkey, val6, expectedProofNodes, crypto.Pedersen)) }) - t.Run("VP non existent key - less than root edge", func(t *testing.T) { - tempTrie, _ := buildSimpleDoubleBinaryTrie(t) + t.Run("left 
non-existent edge proof", func(t *testing.T) { + t.Parallel() - nonExistentKey := trie.NewKey(123, []byte{0}) // Diverges before the root node (len root node = 249) - nonExistentKeyValue := new(felt.Felt).SetUint64(2) - proofNodes, err := trie.GetProof(&nonExistentKey, tempTrie) + start := 100 + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(decrementFelt(records[start].key), records[start].key, proof) require.NoError(t, err) - root, err := tempTrie.Root() + _, err = trie.VerifyRangeProof(root, decrementFelt(records[start].key), []*felt.Felt{records[start].key}, []*felt.Felt{records[start].value}, proof) require.NoError(t, err) - - require.False(t, trie.VerifyProof(root, &nonExistentKey, nonExistentKeyValue, proofNodes, crypto.Pedersen)) }) - t.Run("VP non existent leaf key", func(t *testing.T) { - tempTrie, _ := buildSimpleDoubleBinaryTrie(t) + t.Run("right non-existent edge proof", func(t *testing.T) { + t.Parallel() - nonExistentKeyByte := new(felt.Felt).SetUint64(2).Bytes() // Key not set - nonExistentKey := trie.NewKey(251, nonExistentKeyByte[:]) - nonExistentKeyValue := new(felt.Felt).SetUint64(2) - proofNodes, err := trie.GetProof(&nonExistentKey, tempTrie) + end := 100 + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(records[end].key, incrementFelt(records[end].key), proof) require.NoError(t, err) - root, err := tempTrie.Root() + _, err = trie.VerifyRangeProof(root, records[end].key, []*felt.Felt{records[end].key}, []*felt.Felt{records[end].value}, proof) require.NoError(t, err) - - require.False(t, trie.VerifyProof(root, &nonExistentKey, nonExistentKeyValue, proofNodes, crypto.Pedersen)) }) -} -func TestProofToPath(t *testing.T) { - t.Run("PTP Proof To Path Simple binary trie proof to path", func(t *testing.T) { - tempTrie := buildSimpleTrie(t) - zeroFeltByte := new(felt.Felt).Bytes() - zero := trie.NewKey(250, zeroFeltByte[:]) - leafValue := utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000002") - siblingValue := utils.HexToFelt(t, "0x0000000000000000000000000000000000000000000000000000000000000003") - proofNodes := []trie.ProofNode{ - &trie.Edge{ - Path: &zero, - Child: utils.HexToFelt(t, "0x05774FA77B3D843AE9167ABD61CF80365A9B2B02218FC2F628494B5BDC9B33B8"), - }, - &trie.Binary{ - LeftHash: leafValue, - RightHash: siblingValue, - }, - } + t.Run("both non-existent edge proofs", func(t *testing.T) { + t.Parallel() - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - leafkey := trie.NewKey(251, zeroFeltBytes[:]) - sns, err := trie.ProofToPath(proofNodes, &leafkey, crypto.Pedersen) + start := 100 + first, last := decrementFelt(records[start].key), incrementFelt(records[start].key) + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(first, last, proof) require.NoError(t, err) - rootKey := tempTrie.RootKey() - - require.Equal(t, 1, len(sns)) - require.Equal(t, rootKey.Len(), sns[0].Key().Len()) - require.Equal(t, leafValue.String(), sns[0].Node().LeftHash.String()) - require.Equal(t, siblingValue.String(), sns[0].Node().RightHash.String()) - }) - - t.Run("PTP Simple double binary trie proof to path", func(t *testing.T) { - tempTrie := buildSimpleBinaryRootTrie(t) - - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - leafkey := trie.NewKey(251, zeroFeltBytes[:]) - path1 := trie.NewKey(250, zeroFeltBytes[:]) - proofNodes := []trie.ProofNode{ - &trie.Binary{ - LeftHash: utils.HexToFelt(t, "0x06E08BF82793229338CE60B65D1845F836C8E2FBFE2BC59FF24AEDBD8BA219C4"), - RightHash: utils.HexToFelt(t, 
"0x04F9B8E66212FB528C0C1BD02F43309C53B895AA7D9DC91180001BDD28A588FA"), - }, - &trie.Edge{ - Path: &path1, - Child: utils.HexToFelt(t, "0xcc"), - }, - } - - siblingValue := utils.HexToFelt(t, "0xdd") - sns, err := trie.ProofToPath(proofNodes, &leafkey, crypto.Pedersen) - require.NoError(t, err) - rootKey := tempTrie.RootKey() - rootNode, err := tempTrie.GetNodeFromKey(rootKey) + _, err = trie.VerifyRangeProof(root, first, []*felt.Felt{records[start].key}, []*felt.Felt{records[start].value}, proof) require.NoError(t, err) - leftNode, err := tempTrie.GetNodeFromKey(rootNode.Left) - require.NoError(t, err) - require.Equal(t, 1, len(sns)) - require.Equal(t, rootKey.Len(), sns[0].Key().Len()) - require.Equal(t, leftNode.HashFromParent(rootKey, rootNode.Left, crypto.Pedersen).String(), sns[0].Node().LeftHash.String()) - require.NotEqual(t, siblingValue.String(), sns[0].Node().RightHash.String()) }) - t.Run("PTP boundary proofs with three key trie", func(t *testing.T) { - tri := build3KeyTrie(t) - rootKey := tri.RootKey() - rootNode, err := tri.GetNodeFromKey(rootKey) - require.NoError(t, err) + t.Run("1 key trie", func(t *testing.T) { + t.Parallel() - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) - zeroLeafValue := new(felt.Felt).SetUint64(4) - oneLeafValue := new(felt.Felt).SetUint64(5) - twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() - twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) - bProofs, err := trie.GetBoundaryProofs(&zeroLeafkey, &twoLeafkey, tri) + tr, records := build1KeyTrie(t) + root, err := tr.Root() require.NoError(t, err) - // Test 1 - leftProofPath, err := trie.ProofToPath(bProofs[0], &zeroLeafkey, crypto.Pedersen) - require.Equal(t, 2, len(leftProofPath)) - require.NoError(t, err) - left, err := tri.GetNodeFromKey(rootNode.Left) - require.NoError(t, err) - right, err := tri.GetNodeFromKey(rootNode.Right) + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(&felt.Zero, records[0].key, proof) require.NoError(t, err) - require.Equal(t, rootKey, leftProofPath[0].Key()) - require.Equal(t, left.HashFromParent(rootKey, rootNode.Left, crypto.Pedersen).String(), leftProofPath[0].Node().LeftHash.String()) - require.Equal(t, right.HashFromParent(rootKey, rootNode.Right, crypto.Pedersen).String(), leftProofPath[0].Node().RightHash.String()) - require.Equal(t, rootNode.Left, leftProofPath[1].Key()) - require.Equal(t, zeroLeafValue.String(), leftProofPath[1].Node().LeftHash.String()) - require.Equal(t, oneLeafValue.String(), leftProofPath[1].Node().RightHash.String()) - - // Test 2 - rightProofPath, err := trie.ProofToPath(bProofs[1], &twoLeafkey, crypto.Pedersen) - require.Equal(t, 1, len(rightProofPath)) + + _, err = trie.VerifyRangeProof(root, records[0].key, []*felt.Felt{records[0].key}, []*felt.Felt{records[0].value}, proof) require.NoError(t, err) - require.Equal(t, rootKey, rightProofPath[0].Key()) - require.NotEqual(t, rootNode.Right, rightProofPath[0].Node().Right) - require.NotEqual(t, uint8(0), rightProofPath[0].Node().Right) - require.Equal(t, right.HashFromParent(rootKey, rootNode.Right, crypto.Pedersen).String(), rightProofPath[0].Node().RightHash.String()) }) } -func TestBuildTrie(t *testing.T) { - t.Run("Simple binary trie proof to path", func(t *testing.T) { - // Node (edge path 249) - // / \ - // Node (binary) 0x6 (leaf) - // / \ - // 0x4 0x5 (leaf, leaf) - - tri := build3KeyTrie(t) - rootKey := tri.RootKey() - rootCommitment, err := tri.Root() - require.NoError(t, err) - rootNode, err := 
tri.GetNodeFromKey(rootKey)
-		require.NoError(t, err)
-		leftNode, err := tri.GetNodeFromKey(rootNode.Left)
-		require.NoError(t, err)
-		leftleftNode, err := tri.GetNodeFromKey(leftNode.Left)
-		require.NoError(t, err)
-		leftrightNode, err := tri.GetNodeFromKey(leftNode.Right)
-		require.NoError(t, err)
+// TestAllElementsRangeProof tests the range proof with all elements and a nil proof.
+func TestAllElementsRangeProof(t *testing.T) {
+	t.Parallel()
 
-		zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes()
-		zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:])
-		twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes()
-		twoLeafkey := trie.NewKey(251, twoFeltBytes[:])
-		bProofs, err := trie.GetBoundaryProofs(&zeroLeafkey, &twoLeafkey, tri)
-		require.NoError(t, err)
-
-		leftProof, err := trie.ProofToPath(bProofs[0], &zeroLeafkey, crypto.Pedersen)
-		require.NoError(t, err)
+	n := 1000
+	tr, records := randomTrie(t, n)
+	root, err := tr.Root()
+	require.NoError(t, err)
 
-		rightProof, err := trie.ProofToPath(bProofs[1], &twoLeafkey, crypto.Pedersen)
-		require.NoError(t, err)
+	keys := make([]*felt.Felt, n)
+	values := make([]*felt.Felt, n)
+	for i, record := range records {
+		keys[i] = record.key
+		values[i] = record.value
+	}
 
-		keys := []*felt.Felt{new(felt.Felt).SetUint64(1)}
-		values := []*felt.Felt{new(felt.Felt).SetUint64(5)}
-		builtTrie, err := trie.BuildTrie(leftProof, rightProof, keys, values)
-		require.NoError(t, err)
+	_, err = trie.VerifyRangeProof(root, nil, keys, values, nil)
+	require.NoError(t, err)
 
-		builtRootKey := builtTrie.RootKey()
-		builtRootNode, err := builtTrie.GetNodeFromKey(builtRootKey)
-		require.NoError(t, err)
-		builtLeftNode, err := builtTrie.GetNodeFromKey(builtRootNode.Left)
-		require.NoError(t, err)
-		builtLeftRightNode, err := builtTrie.GetNodeFromKey(builtLeftNode.Right)
-		require.NoError(t, err)
+	// Should also work with a proof
+	proof := trie.NewProofNodeSet()
+	err = tr.GetRangeProof(records[0].key, records[n-1].key, proof)
+	require.NoError(t, err)
 
-		// Assert the structure / keys correct
-		require.Equal(t, rootKey, builtRootKey)
-		require.Equal(t, rootNode.Left, builtRootNode.Left, "left fail")
-		require.Equal(t, leftrightNode.Right, builtLeftRightNode.Right, "right fail")
-		require.Equal(t, uint8(0), builtRootNode.Right.Len(), "right fail")
-		require.Equal(t, uint8(0), builtLeftNode.Left.Len(), "left left fail")
+	_, err = trie.VerifyRangeProof(root, keys[0], keys, values, proof)
+	require.NoError(t, err)
+}
 
-		// Assert the leaf nodes have the correct values
-		require.Equal(t, leftleftNode.Value.String(), builtLeftNode.LeftHash.String(), "should be 0x4")
-		require.Equal(t, leftrightNode.Value.String(), builtLeftRightNode.Value.String(), "should be 0x5")
+// TestSingleSideRangeProof tests the range proof starting with zero.
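+// Only the left boundary is fixed at zero, so each iteration proves the full prefix of keys up to records[i].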
+func TestSingleSideRangeProof(t *testing.T) { + t.Parallel() - // Given the above two asserts pass, we should be able to reconstruct the correct commitment - reconstructedRootCommitment, err := builtTrie.Root() - require.NoError(t, err) - require.Equal(t, rootCommitment.String(), reconstructedRootCommitment.String(), "root commitment not equal") - }) -} + tr, records := randomTrie(t, 1000) + root, err := tr.Root() + require.NoError(t, err) -func TestVerifyRangeProof(t *testing.T) { - t.Run("VPR two proofs, single key trie", func(t *testing.T) { - // Node (edge path 249) - // / \ - // Node (binary) 0x6 (leaf) - // / \ - // 0x4 0x5 (leaf, leaf) - - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) - twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() - twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) - - tri := build3KeyTrie(t) - keys := []*felt.Felt{new(felt.Felt).SetUint64(1)} - values := []*felt.Felt{new(felt.Felt).SetUint64(5)} - proofKeys := [2]*trie.Key{&zeroLeafkey, &twoLeafkey} - proofValues := [2]*felt.Felt{new(felt.Felt).SetUint64(4), new(felt.Felt).SetUint64(6)} - rootCommitment, err := tri.Root() - require.NoError(t, err) - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) + for i := 0; i < len(records); i += 100 { + proof := trie.NewProofNodeSet() + err := tr.GetRangeProof(&felt.Zero, records[i].key, proof) require.NoError(t, err) - verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) - require.NoError(t, err) - require.True(t, verif) - }) - t.Run("VPR all keys provided, no proofs needed", func(t *testing.T) { - // Node (edge path 249) - // / \ - // Node (binary) 0x6 (leaf) - // / \ - // 0x4 0x5 (leaf, leaf) - tri := build3KeyTrie(t) - keys := []*felt.Felt{new(felt.Felt).SetUint64(0), new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)} - values := []*felt.Felt{new(felt.Felt).SetUint64(4), new(felt.Felt).SetUint64(5), new(felt.Felt).SetUint64(6)} - proofKeys := [2]*trie.Key{} - proofValues := [2]*felt.Felt{} - proofs := [2][]trie.ProofNode{} - rootCommitment, err := tri.Root() - require.NoError(t, err) - verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) - require.NoError(t, err) - require.True(t, verif) - }) + keys := make([]*felt.Felt, i+1) + values := make([]*felt.Felt, i+1) + for j := 0; j < i+1; j++ { + keys[j] = records[j].key + values[j] = records[j].value + } - t.Run("VPR left proof, all right keys", func(t *testing.T) { - // Node (edge path 249) - // / \ - // Node (binary) 0x6 (leaf) - // / \ - // 0x4 0x5 (leaf, leaf) - - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) - - tri := build3KeyTrie(t) - keys := []*felt.Felt{new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)} - values := []*felt.Felt{new(felt.Felt).SetUint64(5), new(felt.Felt).SetUint64(6)} - proofKeys := [2]*trie.Key{&zeroLeafkey} - proofValues := [2]*felt.Felt{new(felt.Felt).SetUint64(4)} - leftProof, err := trie.GetProof(proofKeys[0], tri) - require.NoError(t, err) - proofs := [2][]trie.ProofNode{leftProof} - rootCommitment, err := tri.Root() - require.NoError(t, err) - verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) + _, err = trie.VerifyRangeProof(root, &felt.Zero, keys, values, proof) require.NoError(t, err) - require.True(t, verif) - }) + } +} - t.Run("VPR right proof, 
all left keys", func(t *testing.T) { - // Node (edge path 249) - // / \ - // Node (binary) 0x6 (leaf) - // / \ - // 0x4 0x5 (leaf, leaf) - twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() - twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) - - tri := build3KeyTrie(t) - keys := []*felt.Felt{new(felt.Felt).SetUint64(0), new(felt.Felt).SetUint64(1)} - values := []*felt.Felt{new(felt.Felt).SetUint64(4), new(felt.Felt).SetUint64(5)} - proofKeys := [2]*trie.Key{nil, &twoLeafkey} - proofValues := [2]*felt.Felt{nil, new(felt.Felt).SetUint64(6)} - rightProof, err := trie.GetProof(proofKeys[1], tri) - require.NoError(t, err) - proofs := [2][]trie.ProofNode{nil, rightProof} - rootCommitment, err := tri.Root() - require.NoError(t, err) - verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) - require.NoError(t, err) - require.True(t, verif) - }) +func TestGappedRangeProof(t *testing.T) { + t.Parallel() + t.Skip("gapped keys will sometimes succeed, the current proof format is not able to handle this") - t.Run("VPR left proof, all inner keys, right proof with non-set key", func(t *testing.T) { - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) + tr, records := nonRandomTrie(t, 5) + root, err := tr.Root() + require.NoError(t, err) - threeFeltBytes := new(felt.Felt).SetUint64(3).Bytes() - threeLeafkey := trie.NewKey(251, threeFeltBytes[:]) + first, last := 1, 4 + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(records[first].key, records[last].key, proof) + require.NoError(t, err) - tri := build4KeyTrie(t) - keys := []*felt.Felt{new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(2)} - values := []*felt.Felt{new(felt.Felt).SetUint64(5), new(felt.Felt).SetUint64(6)} - proofKeys := [2]*trie.Key{&zeroLeafkey, &threeLeafkey} - proofValues := [2]*felt.Felt{new(felt.Felt).SetUint64(4), nil} - leftProof, err := trie.GetProof(proofKeys[0], tri) - require.NoError(t, err) - rightProof, err := trie.GetProof(proofKeys[1], tri) - require.NoError(t, err) + keys := []*felt.Felt{} + values := []*felt.Felt{} + for i := first; i <= last; i++ { + if i == (first+last)/2 { + continue + } - proofs := [2][]trie.ProofNode{leftProof, rightProof} - rootCommitment, err := tri.Root() - require.NoError(t, err) + keys = append(keys, records[i].key) + values = append(values, records[i].value) + } - verif, err := trie.VerifyRangeProof(rootCommitment, keys, values, proofKeys, proofValues, proofs, crypto.Pedersen) - require.NoError(t, err) - require.True(t, verif) - }) + _, err = trie.VerifyRangeProof(root, records[first].key, keys, values, proof) + require.Error(t, err) } -func TestMergeProofPaths(t *testing.T) { - t.Run("3Key Trie no duplicates and all values exist in merged path", func(t *testing.T) { - tri := build3KeyTrie(t) - - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) +func TestEmptyRangeProof(t *testing.T) { + t.Parallel() - twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() - twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) - - proofKeys := [2]*trie.Key{&zeroLeafkey, &twoLeafkey} + tr, records := randomTrie(t, 1000) + root, err := tr.Root() + require.NoError(t, err) - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) + cases := []struct { + pos int + err bool + }{ + {len(records) - 1, false}, + {500, true}, + } - mergedProofs, _, err := trie.MergeProofPaths(proofs[0], proofs[1], crypto.Pedersen) 
+ for _, c := range cases { + proof := trie.NewProofNodeSet() + first := incrementFelt(records[c.pos].key) + err = tr.GetRangeProof(first, first, proof) require.NoError(t, err) - require.True(t, containsAll(mergedProofs, proofs[0])) - require.True(t, containsAll(mergedProofs, proofs[1])) - require.True(t, noDuplicates(mergedProofs)) - }) - - t.Run("4Key Trie two common ancestors", func(t *testing.T) { - twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() - twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) + _, err := trie.VerifyRangeProof(root, first, nil, nil, proof) + if c.err { + require.Error(t, err) + } else { + require.NoError(t, err) + } + } +} - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) +func TestHasRightElement(t *testing.T) { + t.Parallel() - tri := build4KeyTrie(t) + tr, records := randomTrie(t, 500) + root, err := tr.Root() + require.NoError(t, err) - proofKeys := [2]*trie.Key{&zeroLeafkey, &twoLeafkey} + cases := []struct { + start int + end int + hasMore bool + }{ + {-1, 1, true}, // single element with non-existent left proof + {0, 1, true}, // single element with existent left proof + {0, 100, true}, // start to middle + {50, 100, true}, // middle only + {50, len(records), false}, // middle to end + {len(records) - 1, len(records), false}, // Single last element with two existent proofs(point to same key) + {0, len(records), false}, // The whole set with existent left proof + {-1, len(records), false}, // The whole set with non-existent left proof + } - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) + for _, c := range cases { + var ( + first *felt.Felt + start = c.start + end = c.end + proof = trie.NewProofNodeSet() + ) + if start == -1 { + first = &felt.Zero + start = 0 + } else { + first = records[start].key + } - mergedProofs, _, err := trie.MergeProofPaths(proofs[0], proofs[1], crypto.Pedersen) + err := tr.GetRangeProof(first, records[end-1].key, proof) require.NoError(t, err) - require.True(t, containsAll(mergedProofs, proofs[0])) - require.True(t, containsAll(mergedProofs, proofs[1])) - require.True(t, noDuplicates(mergedProofs)) - }) + keys := []*felt.Felt{} + values := []*felt.Felt{} + for i := start; i < end; i++ { + keys = append(keys, records[i].key) + values = append(values, records[i].value) + } - t.Run("Trie 4Key one ancestor", func(t *testing.T) { - tri := build4KeyTrie(t) - fourFeltBytes := new(felt.Felt).SetUint64(4).Bytes() - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + hasMore, err := trie.VerifyRangeProof(root, first, keys, values, proof) + require.NoError(t, err) + require.Equal(t, c.hasMore, hasMore) + } +} - fourLeafkey := trie.NewKey(251, fourFeltBytes[:]) - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) +// TestBadRangeProof generates random bad proof scenarios and verifies that the proof is invalid. 
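+// Each iteration corrupts one key or value, swaps a pair out of order, or zeroes an entry, and expects verification to fail.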
+func TestBadRangeProof(t *testing.T) { + t.Parallel() - proofKeys := [2]*trie.Key{&zeroLeafkey, &fourLeafkey} + tr, records := randomTrie(t, 1000) + root, err := tr.Root() + require.NoError(t, err) - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) + for i := 0; i < 100; i++ { + start := rand.Intn(len(records)) + end := rand.Intn(len(records)-start) + start + 1 - mergedProofs, _, err := trie.MergeProofPaths(proofs[0], proofs[1], crypto.Pedersen) + proof := trie.NewProofNodeSet() + err := tr.GetRangeProof(records[start].key, records[end-1].key, proof) require.NoError(t, err) - require.True(t, containsAll(mergedProofs, proofs[0])) - require.True(t, containsAll(mergedProofs, proofs[1])) - require.True(t, noDuplicates(mergedProofs)) - }) - - t.Run("Empty proof path", func(t *testing.T) { - tri := build4KeyTrie(t) - fourFeltBytes := new(felt.Felt).SetUint64(4).Bytes() - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() + keys := []*felt.Felt{} + values := []*felt.Felt{} + for j := start; j < end; j++ { + keys = append(keys, records[j].key) + values = append(values, records[j].value) + } - fourLeafkey := trie.NewKey(251, fourFeltBytes[:]) - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) + first := keys[0] + testCase := rand.Intn(5) + + index := rand.Intn(end - start) + switch testCase { + case 0: // modified key + keys[index] = new(felt.Felt).SetUint64(rand.Uint64()) + case 1: // modified value + values[index] = new(felt.Felt).SetUint64(rand.Uint64()) + case 2: // out of order + index2 := rand.Intn(end - start) + if index2 == index { + continue + } + keys[index], keys[index2] = keys[index2], keys[index] + values[index], values[index2] = values[index2], values[index] + case 3: // set random key to empty + keys[index] = &felt.Zero + case 4: // set random value to empty + values[index] = &felt.Zero + // TODO(weiihann): gapped proof will fail sometimes + // case 5: // gapped + // if end-start < 100 || index == 0 || index == end-start-1 { + // continue + // } + // keys = append(keys[:index], keys[index+1:]...) + // values = append(values[:index], values[index+1:]...) 
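+		// (Disabled because verification of gapped ranges can sometimes succeed; see the skip reason in TestGappedRangeProof.)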
+ } + _, err = trie.VerifyRangeProof(root, first, keys, values, proof) + if err == nil { + t.Fatalf("expected error for test case %d, index %d, start %d, end %d", testCase, index, start, end) + } + } +} - proofKeys := [2]*trie.Key{&zeroLeafkey, &fourLeafkey} +func BenchmarkProve(b *testing.B) { + tr, records := randomTrie(b, 1000) + b.ResetTimer() + for i := 0; i < b.N; i++ { + proof := trie.NewProofNodeSet() + key := records[i%len(records)].key + if err := tr.Prove(key, proof); err != nil { + b.Fatal(err) + } + } +} - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) +func BenchmarkVerifyProof(b *testing.B) { + tr, records := randomTrie(b, 1000) + root, err := tr.Root() + require.NoError(b, err) - emptyPath := []trie.ProofNode{} + var proofs []*trie.ProofNodeSet + for _, record := range records { + proof := trie.NewProofNodeSet() + if err := tr.Prove(record.key, proof); err != nil { + b.Fatal(err) + } + proofs = append(proofs, proof) + } - _, _, err = trie.MergeProofPaths(proofs[0], emptyPath, crypto.Pedersen) - require.Error(t, err) - }) + b.ResetTimer() + for i := 0; i < b.N; i++ { + index := i % len(records) + if _, err := trie.VerifyProof(root, records[index].key, proofs[index], crypto.Pedersen); err != nil { + b.Fatal(err) + } + } +} - t.Run("Root of the proof paths are different", func(t *testing.T) { - tri := build4KeyTrie(t) - fourFeltBytes := new(felt.Felt).SetUint64(4).Bytes() - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() +func BenchmarkVerifyRangeProof(b *testing.B) { + tr, records := randomTrie(b, 1000) + root, err := tr.Root() + require.NoError(b, err) - fourLeafkey := trie.NewKey(251, fourFeltBytes[:]) - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) + start := 2 + end := start + 500 - proofKeys := [2]*trie.Key{&zeroLeafkey, &fourLeafkey} + proof := trie.NewProofNodeSet() + err = tr.GetRangeProof(records[start].key, records[end-1].key, proof) + require.NoError(b, err) - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) + keys := make([]*felt.Felt, end-start) + values := make([]*felt.Felt, end-start) + for i := start; i < end; i++ { + keys[i-start] = records[i].key + values[i-start] = records[i].value + } - _, _, err = trie.MergeProofPaths(proofs[0], proofs[1][1:], crypto.Pedersen) - require.Error(t, err) - }) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := trie.VerifyRangeProof(root, keys[0], keys, values, proof) + require.NoError(b, err) + } } -func TestSplitProofPaths(t *testing.T) { - t.Run("3Key Trie retrieved right and left proofs are same with the merged ones", func(t *testing.T) { - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) - twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() - twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) - - tri := build3KeyTrie(t) - proofKeys := [2]*trie.Key{&zeroLeafkey, &twoLeafkey} +func buildTrie(t *testing.T, records []*keyValue) *trie.Trie { + if len(records) == 0 { + t.Fatal("records must have at least one element") + } - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) - mergedProofs, rootHash, err := trie.MergeProofPaths(proofs[0], proofs[1], crypto.Pedersen) - require.NoError(t, err) + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) + require.NoError(t, err) - leftSplit, rightSplit, err 
:= trie.SplitProofPath(mergedProofs, rootHash, crypto.Pedersen) + for _, record := range records { + _, err = tempTrie.Put(record.key, record.value) require.NoError(t, err) + } - require.True(t, isSameProofPath(leftSplit, proofs[0])) - require.True(t, isSameProofPath(rightSplit, proofs[1])) - }) + require.NoError(t, tempTrie.Commit()) - t.Run("4Key Trie two common ancestors retrieved right and left proofs are same with the merged ones", func(t *testing.T) { - tri := build4KeyTrie(t) + return tempTrie +} - twoFeltBytes := new(felt.Felt).SetUint64(2).Bytes() - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() +func build1KeyTrie(t *testing.T) (*trie.Trie, []*keyValue) { + return nonRandomTrie(t, 1) +} - twoLeafkey := trie.NewKey(251, twoFeltBytes[:]) - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) +func buildSimpleTrie(t *testing.T) (*trie.Trie, []*keyValue) { + // (250, 0, x1) edge + // | + // (0,0,x1) binary + // / \ + // (2) (3) + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(0), value: new(felt.Felt).SetUint64(2)}, + {key: new(felt.Felt).SetUint64(1), value: new(felt.Felt).SetUint64(3)}, + } - proofKeys := [2]*trie.Key{&zeroLeafkey, &twoLeafkey} + return buildTrie(t, records), records +} - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) +func buildSimpleBinaryRootTrie(t *testing.T) (*trie.Trie, []*keyValue) { + // PF + // (0, 0, x) + // / \ + // (250, 0, cc) (250, 11111.., dd) + // | | + // (cc) (dd) - mergedProofs, rootHash, err := trie.MergeProofPaths(proofs[0], proofs[1], crypto.Pedersen) - require.NoError(t, err) + // JUNO + // (0, 0, x) + // / \ + // (251, 0, cc) (251, 11111.., dd) + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(0), value: utils.HexToFelt(t, "0xcc")}, + {key: utils.HexToFelt(t, "0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), value: utils.HexToFelt(t, "0xdd")}, + } + return buildTrie(t, records), records +} - leftSplit, rightSplit, err := trie.SplitProofPath(mergedProofs, rootHash, crypto.Pedersen) - require.NoError(t, err) +//nolint:dupl +func buildSimpleDoubleBinaryTrie(t *testing.T) (*trie.Trie, []*keyValue) { + // (249,0,x3) // Edge + // | + // (0, 0, x3) // Binary + // / \ + // (0,0,x1) // B (1, 1, 5) // Edge leaf + // / \ | + // (2) (3) (5) + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(0), value: new(felt.Felt).SetUint64(2)}, + {key: new(felt.Felt).SetUint64(1), value: new(felt.Felt).SetUint64(3)}, + {key: new(felt.Felt).SetUint64(3), value: new(felt.Felt).SetUint64(5)}, + } + return buildTrie(t, records), records +} - require.True(t, isSameProofPath(leftSplit, proofs[0])) - require.True(t, isSameProofPath(rightSplit, proofs[1])) - }) +//nolint:dupl +func build3KeyTrie(t *testing.T) (*trie.Trie, []*keyValue) { + // Starknet + // -------- + // + // Edge + // | + // Binary with len 249 parent + // / \ + // Binary (250) Edge with len 250 + // / \ / + // 0x4 0x5 0x6 child - t.Run("4Key Trie one common ancestor retrieved right and left proofs are same with the merged ones", func(t *testing.T) { - tri := build4KeyTrie(t) - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) - fourFeltBytes := new(felt.Felt).SetUint64(4).Bytes() - fourLeafkey := trie.NewKey(251, fourFeltBytes[:]) + // Juno + // ---- + // + // Node (path 249) + // / \ + // Node (binary) \ + // / \ / + // 0x4 0x5 0x6 + records := []*keyValue{ + {key: new(felt.Felt).SetUint64(0), value: new(felt.Felt).SetUint64(4)}, + {key: 
new(felt.Felt).SetUint64(1), value: new(felt.Felt).SetUint64(5)}, + {key: new(felt.Felt).SetUint64(2), value: new(felt.Felt).SetUint64(6)}, + } - proofKeys := [2]*trie.Key{&zeroLeafkey, &fourLeafkey} + return buildTrie(t, records), records +} - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) +func nonRandomTrie(t *testing.T, numKeys int) (*trie.Trie, []*keyValue) { + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) - mergedProofs, rootHash, err := trie.MergeProofPaths(proofs[0], proofs[1], crypto.Pedersen) - require.NoError(t, err) + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) + require.NoError(t, err) - leftSplit, rightSplit, err := trie.SplitProofPath(mergedProofs, rootHash, crypto.Pedersen) + records := make([]*keyValue, numKeys) + for i := 1; i < numKeys+1; i++ { + key := new(felt.Felt).SetUint64(uint64(i)) + records[i-1] = &keyValue{key: key, value: key} + _, err := tempTrie.Put(key, key) require.NoError(t, err) + } - require.True(t, isSameProofPath(leftSplit, proofs[0])) - require.True(t, isSameProofPath(rightSplit, proofs[1])) + sort.Slice(records, func(i, j int) bool { + return records[i].key.Cmp(records[j].key) < 0 }) - t.Run("4Key Trie reversed merge path", func(t *testing.T) { - tri := build4KeyTrie(t) - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) - fourFeltBytes := new(felt.Felt).SetUint64(4).Bytes() - fourLeafkey := trie.NewKey(251, fourFeltBytes[:]) - - proofKeys := [2]*trie.Key{&zeroLeafkey, &fourLeafkey} - - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) - - mergedProofs, rootHash, err := trie.MergeProofPaths(proofs[0], proofs[1], crypto.Pedersen) - require.NoError(t, err) - - for i := 0; i < len(mergedProofs)/2; i++ { - j := len(mergedProofs) - 1 - i - mergedProofs[i], mergedProofs[j] = mergedProofs[j], mergedProofs[i] - } - - leftSplit, rightSplit, err := trie.SplitProofPath(mergedProofs, rootHash, crypto.Pedersen) - require.NoError(t, err) + require.NoError(t, tempTrie.Commit()) - require.True(t, isSameProofPath(leftSplit, proofs[0])) - require.True(t, isSameProofPath(rightSplit, proofs[1])) - }) + return tempTrie, records +} - t.Run("Roothash does not exist", func(t *testing.T) { - tri := build4KeyTrie(t) - zeroFeltBytes := new(felt.Felt).SetUint64(0).Bytes() - zeroLeafkey := trie.NewKey(251, zeroFeltBytes[:]) - fourFeltBytes := new(felt.Felt).SetUint64(4).Bytes() - fourLeafkey := trie.NewKey(251, fourFeltBytes[:]) +func randomTrie(t testing.TB, n int) (*trie.Trie, []*keyValue) { + rrand := rand.New(rand.NewSource(3)) - proofKeys := [2]*trie.Key{&zeroLeafkey, &fourLeafkey} + memdb := pebble.NewMemTest(t) + txn, err := memdb.NewTransaction(true) + require.NoError(t, err) - proofs, err := trie.GetBoundaryProofs(proofKeys[0], proofKeys[1], tri) - require.NoError(t, err) + tempTrie, err := trie.NewTriePedersen(trie.NewStorage(txn, []byte{0}), 251) + require.NoError(t, err) - mergedProofs, _, err := trie.MergeProofPaths(proofs[0], proofs[1], crypto.Pedersen) + records := make([]*keyValue, n) + for i := 0; i < n; i++ { + key := new(felt.Felt).SetUint64(uint64(rrand.Uint32() + 1)) + records[i] = &keyValue{key: key, value: key} + _, err := tempTrie.Put(key, key) require.NoError(t, err) + } - rootHashFalse := new(felt.Felt).SetUint64(0) + require.NoError(t, tempTrie.Commit()) - _, _, err = trie.SplitProofPath(mergedProofs, rootHashFalse, 
crypto.Pedersen) - require.Error(t, err) + // Sort records by key + sort.Slice(records, func(i, j int) bool { + return records[i].key.Cmp(records[j].key) < 0 }) - t.Run("Two splits in the merged path", func(t *testing.T) { - p1 := newBinaryProofNode() - p2 := newBinaryProofNode() - p3 := newBinaryProofNode() - p4 := newBinaryProofNode() - p5 := newBinaryProofNode() - - p4.LeftHash = new(felt.Felt).SetUint64(3) - p2.RightHash = new(felt.Felt).SetUint64(4) - - p3.RightHash = p5.Hash(crypto.Pedersen) - p3.LeftHash = p4.Hash(crypto.Pedersen) - p1.RightHash = p3.Hash(crypto.Pedersen) - p1.LeftHash = p2.Hash(crypto.Pedersen) - - mergedProofs := []trie.ProofNode{p1, p2, p3, p4, p5} - rootHash := p1.Hash(crypto.Pedersen) + return tempTrie, records +} - _, _, err := trie.SplitProofPath(mergedProofs, rootHash, crypto.Pedersen) - require.Error(t, err) - }) +func decrementFelt(f *felt.Felt) *felt.Felt { + return new(felt.Felt).Sub(f, new(felt.Felt).SetUint64(1)) +} - t.Run("Duplicate nodes in the merged path", func(t *testing.T) { - p1 := newBinaryProofNode() - p2 := newBinaryProofNode() - p3 := newBinaryProofNode() - p4 := newBinaryProofNode() - p5 := newBinaryProofNode() +func incrementFelt(f *felt.Felt) *felt.Felt { + return new(felt.Felt).Add(f, new(felt.Felt).SetUint64(1)) +} - p3.RightHash = p5.Hash(crypto.Pedersen) - p3.LeftHash = p4.Hash(crypto.Pedersen) - p1.RightHash = p3.Hash(crypto.Pedersen) - p1.LeftHash = p2.Hash(crypto.Pedersen) +type testKey struct { + name string + key *felt.Felt + expected *felt.Felt +} - mergedProofs := []trie.ProofNode{p1, p2, p3, p4, p5} - rootHash := p1.Hash(crypto.Pedersen) +type testTrie struct { + name string + buildFn func(*testing.T) (*trie.Trie, []*keyValue) + testKeys []testKey +} - _, _, err := trie.SplitProofPath(mergedProofs, rootHash, crypto.Pedersen) - require.Error(t, err) - }) +type keyValue struct { + key *felt.Felt + value *felt.Felt } diff --git a/core/trie/trie.go b/core/trie/trie.go index ff978c8709..0dd6f4c77c 100644 --- a/core/trie/trie.go +++ b/core/trie/trie.go @@ -11,8 +11,11 @@ import ( "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/utils" ) +const globalTrieHeight = 251 // TODO(weiihann): this is declared in core also, should be moved to a common place + type hashFunc func(*felt.Felt, *felt.Felt) *felt.Felt // Trie is a dense Merkle Patricia Trie (i.e., all internal nodes have two children). @@ -95,31 +98,8 @@ func RunOnTempTriePoseidon(height uint8, do func(*Trie) error) error { // feltToKey Converts a key, given in felt, to a trie.Key which when followed on a [Trie], // leads to the corresponding [Node] -func (t *Trie) feltToKey(k *felt.Felt) Key { - kBytes := k.Bytes() - return NewKey(t.height, kBytes[:]) -} - -// findCommonKey finds the set of common MSB bits in two key bitsets. 
-func findCommonKey(longerKey, shorterKey *Key) (Key, bool) { - divergentBit := findDivergentBit(longerKey, shorterKey) - commonKey := *shorterKey - commonKey.DeleteLSB(shorterKey.Len() - divergentBit + 1) - return commonKey, divergentBit == shorterKey.Len()+1 -} - -func findDivergentBit(longerKey, shorterKey *Key) uint8 { - divergentBit := uint8(0) - for divergentBit <= shorterKey.Len() && - longerKey.Test(longerKey.Len()-divergentBit) == shorterKey.Test(shorterKey.Len()-divergentBit) { - divergentBit++ - } - return divergentBit -} - -func isSubset(longerKey, shorterKey *Key) bool { - divergentBit := findDivergentBit(longerKey, shorterKey) - return divergentBit == shorterKey.Len()+1 +func (t *Trie) FeltToKey(k *felt.Felt) Key { + return FeltToKey(t.height, k) } // path returns the path as mentioned in the [specification] for commitment calculations. @@ -145,14 +125,97 @@ func (sn *StorageNode) Key() *Key { return sn.key } -func (sn *StorageNode) Node() *Node { - return sn.node +func (sn *StorageNode) Value() *felt.Felt { + return sn.node.Value +} + +func (sn *StorageNode) String() string { + return fmt.Sprintf("StorageNode{key: %s, node: %s}", sn.key, sn.node) +} + +func (sn *StorageNode) Update(other *StorageNode) error { + // First validate all fields for conflicts + if sn.key != nil && other.key != nil && !sn.key.Equal(NilKey) && !other.key.Equal(NilKey) { + if !sn.key.Equal(other.key) { + return fmt.Errorf("keys do not match: %s != %s", sn.key, other.key) + } + } + + // Validate node updates + if sn.node != nil && other.node != nil { + if err := sn.node.Update(other.node); err != nil { + return err + } + } + + // After validation, perform update + if other.key != nil && !other.key.Equal(NilKey) { + sn.key = other.key + } + + return nil } func NewStorageNode(key *Key, node *Node) *StorageNode { return &StorageNode{key: key, node: node} } +// NewPartialStorageNode creates a new StorageNode with a given key and value, +// where the right and left children are nil. +func NewPartialStorageNode(key *Key, value *felt.Felt) *StorageNode { + return &StorageNode{ + key: key, + node: &Node{ + Value: value, + Left: NilKey, + Right: NilKey, + }, + } +} + +// StorageNodeSet wraps OrderedSet to provide specific functionality for StorageNodes +type StorageNodeSet struct { + set *utils.OrderedSet[Key, *StorageNode] +} + +func NewStorageNodeSet() *StorageNodeSet { + return &StorageNodeSet{ + set: utils.NewOrderedSet[Key, *StorageNode](), + } +} + +func (s *StorageNodeSet) Get(key Key) (*StorageNode, bool) { + return s.set.Get(key) +} + +// Put adds a new StorageNode or updates an existing one. +func (s *StorageNodeSet) Put(key Key, node *StorageNode) error { + if node == nil { + return fmt.Errorf("cannot put nil node") + } + + // If key exists, update the node + if existingNode, exists := s.set.Get(key); exists { + if err := existingNode.Update(node); err != nil { + return fmt.Errorf("failed to update node for key %v: %w", key, err) + } + return nil + } + + // Add new node if key doesn't exist + s.set.Put(key, node) + return nil +} + +// List returns the list of StorageNodes in the set. +func (s *StorageNodeSet) List() []*StorageNode { + return s.set.List() +} + +func (s *StorageNodeSet) Size() int { + return s.set.Size() +} + // nodesFromRoot enumerates the set of [Node] objects that need to be traversed from the root // of the Trie to the node which is given by the key. // The [storageNode]s are returned in descending order beginning with the root. 
@@ -180,7 +243,7 @@ func (t *Trie) nodesFromRoot(key *Key) ([]StorageNode, error) { return nodes, nil } - if key.Test(key.Len() - cur.Len() - 1) { + if key.IsBitSet(key.Len() - cur.Len() - 1) { cur = node.Right } else { cur = node.Left @@ -192,7 +255,7 @@ func (t *Trie) nodesFromRoot(key *Key) ([]StorageNode, error) { // Get the corresponding `value` for a `key` func (t *Trie) Get(key *felt.Felt) (*felt.Felt, error) { - storageKey := t.feltToKey(key) + storageKey := t.FeltToKey(key) value, err := t.storage.Get(&storageKey) if err != nil { if errors.Is(err, db.ErrKeyNotFound) { @@ -261,6 +324,7 @@ func (t *Trie) replaceLinkWithNewParent(key *Key, commonKey Key, siblingParent S } } +// TODO(weiihann): not a good idea to couple proof verification logic with trie logic func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []StorageNode, sibling StorageNode, siblingIsParentProof bool) error { commonKey, _ := findCommonKey(nodeKey, sibling.key) @@ -274,7 +338,7 @@ func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []StorageNode if err != nil { return err } - if nodeKey.Test(nodeKey.Len() - commonKey.Len() - 1) { + if nodeKey.IsBitSet(nodeKey.Len() - commonKey.Len() - 1) { newParent.Right = nodeKey newParent.RightHash = node.Hash(nodeKey, t.hash) } else { @@ -286,7 +350,7 @@ func (t *Trie) insertOrUpdateValue(nodeKey *Key, node *Node, nodes []StorageNode } t.dirtyNodes = append(t.dirtyNodes, &commonKey) } else { - if nodeKey.Test(nodeKey.Len() - commonKey.Len() - 1) { + if nodeKey.IsBitSet(nodeKey.Len() - commonKey.Len() - 1) { newParent.Left, newParent.Right = sibling.key, nodeKey leftChild, rightChild = sibling.node, node } else { @@ -328,7 +392,7 @@ func (t *Trie) Put(key, value *felt.Felt) (*felt.Felt, error) { } old := felt.Zero - nodeKey := t.feltToKey(key) + nodeKey := t.FeltToKey(key) node := &Node{ Value: value, } @@ -373,13 +437,13 @@ func (t *Trie) Put(key, value *felt.Felt) (*felt.Felt, error) { } // Put updates the corresponding `value` for a `key` -func (t *Trie) PutWithProof(key, value *felt.Felt, lProofPath, rProofPath []StorageNode) (*felt.Felt, error) { +func (t *Trie) PutWithProof(key, value *felt.Felt, proof []*StorageNode) (*felt.Felt, error) { if key.Cmp(t.maxKey) > 0 { return nil, fmt.Errorf("key %s exceeds trie height %d", key, t.height) } old := felt.Zero - nodeKey := t.feltToKey(key) + nodeKey := t.FeltToKey(key) node := &Node{ Value: value, } @@ -417,24 +481,14 @@ func (t *Trie) PutWithProof(key, value *felt.Felt, lProofPath, rProofPath []Stor } // override the sibling to be the parent if it's a proof - parentIsProof, found := false, false - for _, proof := range lProofPath { - if proof.key.Equal(sibling.key) { - sibling = proof + parentIsProof := false + for _, proofNode := range proof { + if proofNode.key.Equal(sibling.key) { + sibling = *proofNode parentIsProof = true - found = true break } } - if !found { - for _, proof := range rProofPath { - if proof.key.Equal(sibling.key) { - sibling = proof - parentIsProof = true - break - } - } - } err := t.insertOrUpdateValue(&nodeKey, node, nodes, sibling, parentIsProof) if err != nil { @@ -445,14 +499,11 @@ func (t *Trie) PutWithProof(key, value *felt.Felt, lProofPath, rProofPath []Stor } // Put updates the corresponding `value` for a `key` -func (t *Trie) PutInner(key *Key, node *Node) (*felt.Felt, error) { +func (t *Trie) PutInner(key *Key, node *Node) error { if err := t.storage.Put(key, node); err != nil { - return nil, err - } - if t.rootKey == nil { - t.setRootKey(key) + return err } - 
return &felt.Zero, nil + return nil } func (t *Trie) setRootKey(newRootKey *Key) { @@ -461,9 +512,6 @@ func (t *Trie) setRootKey(newRootKey *Key) { } func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo - zeroFeltBytes := new(felt.Felt).Bytes() - nilKey := NewKey(0, zeroFeltBytes[:]) - node, err := t.storage.Get(key) if err != nil { return nil, err @@ -485,9 +533,9 @@ func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo } // Update inner proof nodes - if node.Left.Equal(&nilKey) && node.Right.Equal(&nilKey) { // leaf + if node.Left.Equal(NilKey) && node.Right.Equal(NilKey) { // leaf shouldUpdate = false - } else if node.Left.Equal(&nilKey) || node.Right.Equal(&nilKey) { // inner + } else if node.Left.Equal(NilKey) || node.Right.Equal(NilKey) { // inner shouldUpdate = true } if !shouldUpdate { @@ -496,11 +544,11 @@ func (t *Trie) updateValueIfDirty(key *Key) (*Node, error) { //nolint:gocyclo var leftIsProof, rightIsProof bool var leftHash, rightHash *felt.Felt - if node.Left.Equal(&nilKey) { + if node.Left.Equal(NilKey) { // key could be nil but hash cannot be leftIsProof = true leftHash = node.LeftHash } - if node.Right.Equal(&nilKey) { + if node.Right.Equal(NilKey) { rightIsProof = true rightHash = node.RightHash } @@ -698,12 +746,33 @@ func (t *Trie) dump(level int, parentP *Key) { } defer nodePool.Put(root) path := path(t.rootKey, parentP) - fmt.Printf("%sstorage : \"%s\" %d spec: \"%s\" %d bottom: \"%s\" \n", + + left := "" + right := "" + leftHash := "" + rightHash := "" + + if root.Left != nil { + left = root.Left.String() + } + if root.Right != nil { + right = root.Right.String() + } + if root.LeftHash != nil { + leftHash = root.LeftHash.String() + } + if root.RightHash != nil { + rightHash = root.RightHash.String() + } + + fmt.Printf("%skey : \"%s\" path: \"%s\" left: \"%s\" right: \"%s\" LH: \"%s\" RH: \"%s\" value: \"%s\" \n", strings.Repeat("\t", level), t.rootKey.String(), - t.rootKey.Len(), path.String(), - path.Len(), + left, + right, + leftHash, + rightHash, root.Value.String(), ) (&Trie{ diff --git a/core/trie/trie_pkg_test.go b/core/trie/trie_pkg_test.go index 87f1801f78..5426cbcafa 100644 --- a/core/trie/trie_pkg_test.go +++ b/core/trie/trie_pkg_test.go @@ -26,7 +26,7 @@ func TestTrieKeys(t *testing.T) { require.NoError(t, err) assert.Equal(t, val, value, "key-val not match") - assert.Equal(t, tempTrie.feltToKey(key), *tempTrie.rootKey, "root key not match single node's key") + assert.Equal(t, tempTrie.FeltToKey(key), *tempTrie.rootKey, "root key not match single node's key") }) t.Run("put a left then a right node", func(t *testing.T) { @@ -53,8 +53,8 @@ func TestTrieKeys(t *testing.T) { require.NoError(t, err) // Check parent and its left right children - l := tempTrie.feltToKey(leftKey) - r := tempTrie.feltToKey(rightKey) + l := tempTrie.FeltToKey(leftKey) + r := tempTrie.FeltToKey(rightKey) commonKey, isSame := findCommonKey(&l, &r) require.False(t, isSame) @@ -69,8 +69,8 @@ func TestTrieKeys(t *testing.T) { parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) - assert.Equal(t, tempTrie.feltToKey(leftKey), *parentNode.Left) - assert.Equal(t, tempTrie.feltToKey(rightKey), *parentNode.Right) + assert.Equal(t, tempTrie.FeltToKey(leftKey), *parentNode.Left) + assert.Equal(t, tempTrie.FeltToKey(rightKey), *parentNode.Right) }) t.Run("put a right node then a left node", func(t *testing.T) { @@ -96,8 +96,8 @@ func TestTrieKeys(t *testing.T) { require.NoError(t, err) // Check parent and its left right 
children - l := tempTrie.feltToKey(leftKey) - r := tempTrie.feltToKey(rightKey) + l := tempTrie.FeltToKey(leftKey) + r := tempTrie.FeltToKey(rightKey) commonKey, isSame := findCommonKey(&l, &r) require.False(t, isSame) @@ -108,8 +108,8 @@ func TestTrieKeys(t *testing.T) { parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) - assert.Equal(t, tempTrie.feltToKey(leftKey), *parentNode.Left) - assert.Equal(t, tempTrie.feltToKey(rightKey), *parentNode.Right) + assert.Equal(t, tempTrie.FeltToKey(leftKey), *parentNode.Left) + assert.Equal(t, tempTrie.FeltToKey(rightKey), *parentNode.Right) }) t.Run("Add new key to different branches", func(t *testing.T) { @@ -142,8 +142,8 @@ func TestTrieKeys(t *testing.T) { commonKey := NewKey(250, []byte{0x2}) parentNode, pErr := tempTrie.storage.Get(&commonKey) require.NoError(t, pErr) - assert.Equal(t, tempTrie.feltToKey(leftKey), *parentNode.Left) - assert.Equal(t, tempTrie.feltToKey(newKey), *parentNode.Right) + assert.Equal(t, tempTrie.FeltToKey(leftKey), *parentNode.Left) + assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Right) }) //nolint: dupl t.Run("Add to right branch", func(t *testing.T) { @@ -153,8 +153,8 @@ func TestTrieKeys(t *testing.T) { commonKey := NewKey(250, []byte{0x3}) parentNode, pErr := tempTrie.storage.Get(&commonKey) require.NoError(t, pErr) - assert.Equal(t, tempTrie.feltToKey(newKey), *parentNode.Left) - assert.Equal(t, tempTrie.feltToKey(rightKey), *parentNode.Right) + assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Left) + assert.Equal(t, tempTrie.FeltToKey(rightKey), *parentNode.Right) }) t.Run("Add new node as parent sibling", func(t *testing.T) { newKeyNum, err := strconv.ParseUint("000", 2, 64) @@ -170,7 +170,7 @@ func TestTrieKeys(t *testing.T) { parentNode, err := tempTrie.storage.Get(&commonKey) require.NoError(t, err) - assert.Equal(t, tempTrie.feltToKey(newKey), *parentNode.Left) + assert.Equal(t, tempTrie.FeltToKey(newKey), *parentNode.Left) expectRightKey := NewKey(249, []byte{0x1}) @@ -246,8 +246,8 @@ func TestTrieKeysAfterDeleteSubtree(t *testing.T) { rootNode, err := tempTrie.storage.Get(&newRootKey) require.NoError(t, err) - assert.Equal(t, tempTrie.feltToKey(rightKey), *rootNode.Right) - assert.Equal(t, tempTrie.feltToKey(test.expectLeft), *rootNode.Left) + assert.Equal(t, tempTrie.FeltToKey(rightKey), *rootNode.Right) + assert.Equal(t, tempTrie.FeltToKey(test.expectLeft), *rootNode.Left) }) } } diff --git a/db/pebble/db.go b/db/pebble/db.go index 5974edf720..77aed603d7 100644 --- a/db/pebble/db.go +++ b/db/pebble/db.go @@ -60,7 +60,7 @@ func NewMem() (db.DB, error) { } // NewMemTest opens a new in-memory database, panics on error -func NewMemTest(t *testing.T) db.DB { +func NewMemTest(t testing.TB) db.DB { memDB, err := NewMem() if err != nil { t.Fatalf("create in-memory db: %v", err) diff --git a/utils/orderedset.go b/utils/orderedset.go new file mode 100644 index 0000000000..e6e7e2d948 --- /dev/null +++ b/utils/orderedset.go @@ -0,0 +1,67 @@ +package utils + +import ( + "sync" +) + +// OrderedSet is a thread-safe data structure that maintains both uniqueness and insertion order of elements. +// It combines the benefits of both maps and slices: +// - Uses a map for O(1) lookups and to ensure element uniqueness +// - Uses a slice to maintain insertion order and enable ordered iteration +// The data structure is safe for concurrent access through the use of a read-write mutex. 
+type OrderedSet[K comparable, V any] struct { + itemPos map[K]int // position of the node in the list + items []V + size int + lock sync.RWMutex +} + +func NewOrderedSet[K comparable, V any]() *OrderedSet[K, V] { + return &OrderedSet[K, V]{ + itemPos: make(map[K]int), + } +} + +func (ps *OrderedSet[K, V]) Put(key K, value V) { + ps.lock.Lock() + defer ps.lock.Unlock() + + // Update existing entry + if pos, exists := ps.itemPos[key]; exists { + ps.items[pos] = value + return + } + + // Insert new entry + ps.itemPos[key] = len(ps.items) + ps.items = append(ps.items, value) + ps.size++ +} + +func (ps *OrderedSet[K, V]) Get(key K) (V, bool) { + ps.lock.RLock() + defer ps.lock.RUnlock() + + if pos, ok := ps.itemPos[key]; ok { + return ps.items[pos], true + } + var zero V + return zero, false +} + +func (ps *OrderedSet[K, V]) Size() int { + ps.lock.RLock() + defer ps.lock.RUnlock() + + return ps.size +} + +// List returns a shallow copy of the proof set's value list. +func (ps *OrderedSet[K, V]) List() []V { + ps.lock.RLock() + defer ps.lock.RUnlock() + + values := make([]V, len(ps.items)) + copy(values, ps.items) + return values +} From 76873600e9cd8b0c01f3ff77e971ba8df3a6b840 Mon Sep 17 00:00:00 2001 From: Ng Wei Han <47109095+weiihann@users.noreply.github.com> Date: Thu, 12 Dec 2024 18:54:32 +0800 Subject: [PATCH 10/10] Remove size in OrderedSet (#2319) --- utils/orderedset.go | 46 ++++++++++++++++++++++----------------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/utils/orderedset.go b/utils/orderedset.go index e6e7e2d948..1fd9ed3aef 100644 --- a/utils/orderedset.go +++ b/utils/orderedset.go @@ -10,9 +10,8 @@ import ( // - Uses a slice to maintain insertion order and enable ordered iteration // The data structure is safe for concurrent access through the use of a read-write mutex. type OrderedSet[K comparable, V any] struct { - itemPos map[K]int // position of the node in the list + itemPos map[K]int // position of the item in the list items []V - size int lock sync.RWMutex } @@ -22,46 +21,45 @@ func NewOrderedSet[K comparable, V any]() *OrderedSet[K, V] { } } -func (ps *OrderedSet[K, V]) Put(key K, value V) { - ps.lock.Lock() - defer ps.lock.Unlock() +func (o *OrderedSet[K, V]) Put(key K, value V) { + o.lock.Lock() + defer o.lock.Unlock() // Update existing entry - if pos, exists := ps.itemPos[key]; exists { - ps.items[pos] = value + if pos, exists := o.itemPos[key]; exists { + o.items[pos] = value return } // Insert new entry - ps.itemPos[key] = len(ps.items) - ps.items = append(ps.items, value) - ps.size++ + o.itemPos[key] = len(o.items) + o.items = append(o.items, value) } -func (ps *OrderedSet[K, V]) Get(key K) (V, bool) { - ps.lock.RLock() - defer ps.lock.RUnlock() +func (o *OrderedSet[K, V]) Get(key K) (V, bool) { + o.lock.RLock() + defer o.lock.RUnlock() - if pos, ok := ps.itemPos[key]; ok { - return ps.items[pos], true + if pos, ok := o.itemPos[key]; ok { + return o.items[pos], true } var zero V return zero, false } -func (ps *OrderedSet[K, V]) Size() int { - ps.lock.RLock() - defer ps.lock.RUnlock() +func (o *OrderedSet[K, V]) Size() int { + o.lock.RLock() + defer o.lock.RUnlock() - return ps.size + return len(o.items) } // List returns a shallow copy of the proof set's value list. 
-func (ps *OrderedSet[K, V]) List() []V {
-	ps.lock.RLock()
-	defer ps.lock.RUnlock()
+func (o *OrderedSet[K, V]) List() []V {
+	o.lock.RLock()
+	defer o.lock.RUnlock()
 
-	values := make([]V, len(ps.items))
-	copy(values, ps.items)
+	values := make([]V, len(o.items))
+	copy(values, o.items)
 	return values
 }
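
To close out the series, here is a small, self-contained sketch of the finished OrderedSet in use; the keys and values are hypothetical, only the API is taken from the patches above:

package main

import (
	"fmt"

	"github.com/NethermindEth/juno/utils"
)

func main() {
	set := utils.NewOrderedSet[string, int]()
	set.Put("a", 1)
	set.Put("b", 2)
	set.Put("a", 3) // updates in place; "a" keeps its original position

	if v, ok := set.Get("a"); ok {
		fmt.Println(v) // 3
	}

	fmt.Println(set.Size()) // 2: after #2319 this is len(items), with no separate counter
	fmt.Println(set.List()) // [3 2]: a copy, in insertion order
}

Dropping the size field removes a piece of state that had to be kept in sync with the slice; len(o.items) is already O(1) in Go, so the simpler version loses nothing.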