diff --git a/merkle_zeppelin/__init__.py b/merkle_zeppelin/__init__.py
index a1e966d..10a4c85 100644
--- a/merkle_zeppelin/__init__.py
+++ b/merkle_zeppelin/__init__.py
@@ -1,4 +1,18 @@
+from .data_io.base_exporter import MerkleTreeExporter
+from .data_io.base_importer import MerkleTreeImporter
+from .data_io.dto import LeafValueDTO, MerkleTreeDTO
+from .data_io.json_exporter import MerkleTreeJSONExporter
+from .data_io.json_importer import MerkleTreeJSONImporter
 from .trees.binary import BinaryTree
 from .trees.merkle import MerkleTree
 
-__all__ = ["BinaryTree", "MerkleTree"]
+__all__ = [
+    "BinaryTree",
+    "MerkleTree",
+    "LeafValueDTO",
+    "MerkleTreeDTO",
+    "MerkleTreeImporter",
+    "MerkleTreeExporter",
+    "MerkleTreeJSONImporter",
+    "MerkleTreeJSONExporter",
+]
diff --git a/merkle_zeppelin/data_io/__init__.py b/merkle_zeppelin/data_io/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/merkle_zeppelin/data_io/base_exporter.py b/merkle_zeppelin/data_io/base_exporter.py
new file mode 100644
index 0000000..39bb415
--- /dev/null
+++ b/merkle_zeppelin/data_io/base_exporter.py
@@ -0,0 +1,14 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+from .dto import MerkleTreeDTO
+
+
+class MerkleTreeExporter(ABC):
+    """Interface for serialising a ``MerkleTreeDTO`` into an external format."""
+
+    @staticmethod
+    @abstractmethod
+    def export_tree(data: MerkleTreeDTO) -> Any:
+        """Return ``data`` converted to the exporter's output representation."""
+        ...
diff --git a/merkle_zeppelin/data_io/base_importer.py b/merkle_zeppelin/data_io/base_importer.py
new file mode 100644
index 0000000..3f02a21
--- /dev/null
+++ b/merkle_zeppelin/data_io/base_importer.py
@@ -0,0 +1,14 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+from .dto import MerkleTreeDTO
+
+
+class MerkleTreeImporter(ABC):
+    """Interface for parsing an external representation into a ``MerkleTreeDTO``."""
+
+    @staticmethod
+    @abstractmethod
+    def import_tree(data: Any) -> MerkleTreeDTO:
+        """Return the ``MerkleTreeDTO`` described by ``data``."""
+        ...
diff --git a/merkle_zeppelin/data_io/dto.py b/merkle_zeppelin/data_io/dto.py
new file mode 100644
index 0000000..721f4da
--- /dev/null
+++ b/merkle_zeppelin/data_io/dto.py
@@ -0,0 +1,26 @@
+from typing import Any
+
+from pydantic import BaseModel, ConfigDict, Field
+from pydantic.alias_generators import to_camel
+
+Leaf = tuple[Any, ...]
+
+
+class LeafValueDTO(BaseModel):
+    """One raw leaf value together with the index of its node in the tree."""
+
+    model_config = ConfigDict(alias_generator=to_camel)
+
+    value: Leaf
+    tree_index: int
+
+
+class MerkleTreeDTO(BaseModel):
+    """Serialisable form of a Merkle tree; field aliases are camelCase."""
+
+    model_config = ConfigDict(alias_generator=to_camel)
+
+    format: str = Field("standard-v1", frozen=True)
+    tree: list[str]
+    values: list[LeafValueDTO]
+    leaf_encoding: list[str]
diff --git a/merkle_zeppelin/data_io/json_exporter.py b/merkle_zeppelin/data_io/json_exporter.py
new file mode 100644
index 0000000..a65a782
--- /dev/null
+++ b/merkle_zeppelin/data_io/json_exporter.py
@@ -0,0 +1,11 @@
+from .base_exporter import MerkleTreeExporter
+from .dto import MerkleTreeDTO
+
+
+class MerkleTreeJSONExporter(MerkleTreeExporter):
+    """Exporter that renders a ``MerkleTreeDTO`` as a JSON string."""
+
+    @staticmethod
+    def export_tree(data: MerkleTreeDTO) -> str:
+        # by_alias=True writes the camelCase aliases (treeIndex, leafEncoding).
+        return data.model_dump_json(by_alias=True)
diff --git a/merkle_zeppelin/data_io/json_importer.py b/merkle_zeppelin/data_io/json_importer.py
new file mode 100644
index 0000000..892857f
--- /dev/null
+++ b/merkle_zeppelin/data_io/json_importer.py
@@ -0,0 +1,10 @@
+from .base_importer import MerkleTreeImporter
+from .dto import MerkleTreeDTO
+
+
+class MerkleTreeJSONImporter(MerkleTreeImporter):
+    """Importer that parses a JSON string into a ``MerkleTreeDTO``."""
+
+    @staticmethod
+    def import_tree(data: str) -> MerkleTreeDTO:
+        return MerkleTreeDTO.model_validate_json(data)
diff --git a/merkle_zeppelin/trees/merkle.py b/merkle_zeppelin/trees/merkle.py
index f0db338..bf694da 100644
--- a/merkle_zeppelin/trees/merkle.py
+++ b/merkle_zeppelin/trees/merkle.py
@@ -1,13 +1,17 @@
-from typing import Any, Callable, Union
+from __future__ import annotations
+
+from operator import itemgetter
+from typing import Any, Callable, Optional, Type, Union
 
 from Crypto.Hash import keccak
 from eth_abi import encode
 
+from ..data_io.base_exporter import MerkleTreeExporter
+from ..data_io.base_importer import MerkleTreeImporter
+from ..data_io.dto import Leaf, LeafValueDTO, MerkleTreeDTO
 from .binary import BinaryTree
 from .exceptions import MerkleTreeValidationFailed, ValueNotFoundInTree
 
-Leaf = tuple[Any, ...]
-
 
 def keccak256(v: bytes) -> bytes:
     return keccak.new(data=v, digest_bits=256).digest()
@@ -16,32 +20,111 @@ def keccak256(v: bytes) -> bytes:
 class MerkleTree(BinaryTree[bytes]):
     def __init__(
         self,
-        leaves: list[Leaf],
+        raw_elements: list[Leaf],
         types: list[str],
-        hashing_function: Callable[[bytes], bytes] = None,
+        hashing_function: Optional[Callable[[bytes], bytes]] = None,
     ) -> None:
         self._hashing_function = hashing_function or keccak256
         self._types = types
 
-        hash_leaf_pairs = self._get_hash_leaf_pairs(leaves)
-        hash_leaf_pairs.sort(reverse=True)
+        sorted_index_hash_pairs = self._get_sorted_index_hash_pairs(raw_elements)
+        sorted_hashes = [hash_ for _, hash_ in sorted_index_hash_pairs]
 
-        ordered_hashed_leaves = [leaf_hash for leaf_hash, _ in hash_leaf_pairs]
-        ordered_leaves = [leaf for _, leaf in hash_leaf_pairs]
+        super().__init__(sorted_hashes)
 
-        super().__init__(ordered_hashed_leaves)
+        self._raw_to_leaves_index_mapping = self._get_raw_to_leaves_index_mapping(
+            raw_elements, sorted_index_hash_pairs
+        )
+
+    @staticmethod
+    def get_hash_from_string(hash_: str) -> bytes:
+        """Decode a hex string, with or without a ``0x`` prefix, into bytes."""
+        if hash_.startswith("0x"):
+            hash_ = hash_[2:]
+
+        return bytes.fromhex(hash_)
+
+    @classmethod
+    def import_tree(
+        cls,
+        data: Any,
+        importer: Type[MerkleTreeImporter],
+        validate: bool = True,
+        hashing_function: Optional[Callable[[bytes], bytes]] = None,
+    ) -> MerkleTree:
+        """Parse ``data`` with ``importer`` and build a tree from the result."""
+        dto = importer.import_tree(data)
+        return cls.import_tree_from_dto(dto, validate, hashing_function)
+
+    @classmethod
+    def import_tree_from_dto(
+        cls,
+        dto: MerkleTreeDTO,
+        validate: bool = True,
+        hashing_function: Optional[Callable[[bytes], bytes]] = None,
+    ) -> MerkleTree:
+        """Rebuild a tree from ``dto`` without re-hashing its leaves.
+
+        ``__init__`` is bypassed on purpose: node hashes are taken from the
+        DTO as they are. With ``validate=True`` the restored tree is checked
+        by ``validate()``, which raises on a mismatch.
+        """
+        obj = cls.__new__(cls)
+        obj._types = dto.leaf_encoding
+        obj._hashing_function = hashing_function or keccak256
+
+        obj._nodes = [cls.get_hash_from_string(hash_str) for hash_str in dto.tree]
+        obj._raw_to_leaves_index_mapping = {
+            value.value: value.tree_index for value in dto.values
+        }
+
+        if validate:
+            obj.validate()
+
+        return obj
+
+    @property
+    def dto(self) -> MerkleTreeDTO:
+        """Build a ``MerkleTreeDTO`` from the current nodes, values and types."""
+        return MerkleTreeDTO(
+            tree=[f"0x{node.hex()}" for node in self._nodes],
+            values=[
+                LeafValueDTO(value=value, treeIndex=index)
+                for value, index in self._raw_to_leaves_index_mapping.items()
+            ],
+            leafEncoding=self._types,
+        )
 
-        self._raw_leaves_index = self._get_raw_leaves_index(ordered_leaves)
+    def export_tree(self, exporter: Type[MerkleTreeExporter]) -> Any:
+        """Return this tree's DTO converted by ``exporter``."""
+        return exporter.export_tree(self.dto)
 
     def validate(self, raise_exception: bool = True) -> bool:
-        for i in range(1, self._inner_nodes_number):
-            left_node_index = self._get_left_child_index(i)
-            right_node_index = self._get_right_child_index(i)
+        """Check the stored leaf hashes and every parent hash.
+
+        Returns ``False`` instead of raising if ``raise_exception`` is false.
+        """
+        calculated_leaves = sorted(
+            (
+                self._calculate_leaf_hash(value)
+                for value in self._raw_to_leaves_index_mapping
+            ),
+            reverse=True,
+        )
+        if self.leaves != calculated_leaves:
+            if not raise_exception:
+                return False
+
+            raise MerkleTreeValidationFailed()
+
+        # Walk sibling pairs from the last node backwards; the k-th pair's
+        # parent is the k-th inner node counted from the end.
+        for checked, i in enumerate(range(len(self._nodes) - 1, 0, -2)):
             calculated_parent = self._calculate_parent_value(
-                self._nodes[left_node_index], self._nodes[right_node_index]
+                self._nodes[i], self._nodes[i - 1]
             )
 
-            if self._nodes[i] != calculated_parent:
+            if self._nodes[self._inner_nodes_number - checked - 1] != calculated_parent:
                 if not raise_exception:
                     return False
 
@@ -51,7 +134,7 @@ def validate(self, raise_exception: bool = True) -> bool:
 
     def get_proofs(self, value: Leaf) -> Union[list[bytes], None]:
         try:
-            node_index = self._raw_leaves_index[value]
-        except ValueError:
+            node_index = self._raw_to_leaves_index_mapping[value]
+        except KeyError:
             raise ValueNotFoundInTree(value)
 
@@ -64,19 +147,42 @@ def get_proofs(self, value: Leaf) -> Union[list[bytes], None]:
 
         return result
 
+    def _get_sorted_index_hash_pairs(
+        self, raw_elements: list[Leaf]
+    ) -> list[tuple[int, bytes]]:
+        """Return ``(original index, leaf hash)`` pairs, largest hash first."""
+        hashes = [self._calculate_leaf_hash(el) for el in raw_elements]
+
+        return sorted(
+            enumerate(hashes),
+            reverse=True,
+            key=itemgetter(1),  # sort by hash (second position)
+        )
+
+    def _get_raw_to_leaves_index_mapping(
+        self, raw_elements: list[Leaf], ordered_hashed_leaves: list[tuple[int, bytes]]
+    ) -> dict[Leaf, int]:
+        """Map every raw element to the node index of its leaf hash.
+
+        ``ordered_hashed_leaves`` holds ``(original index, hash)`` pairs; a
+        pair's position plus the number of inner nodes is used as its node
+        index. Keys are inserted in ``raw_elements`` order.
+        """
+        position_by_original_index = {
+            original_index: position
+            for position, (original_index, _) in enumerate(ordered_hashed_leaves)
+        }
+
+        return {
+            leaf: position_by_original_index[index] + self._inner_nodes_number
+            for index, leaf in enumerate(raw_elements)
+        }
+
     def _calculate_leaf_hash(self, value: Leaf) -> bytes:
         return self._hashing_function(
             self._hashing_function(encode(self._types, value))
         )
 
-    def _get_hash_leaf_pairs(
-        self, input_leaves: list[Leaf]
-    ) -> list[tuple[bytes, Leaf]]:
-        return [(self._calculate_leaf_hash(leaf), leaf) for leaf in input_leaves]
-
-    def _get_raw_leaves_index(self, hash_leaf_pairs: list[Leaf]) -> dict[Leaf, int]:
-        return {leaf: self._get_node_index(i) for i, leaf in enumerate(hash_leaf_pairs)}
-
     def _calculate_parent_value(self, left_child: bytes, right_child: bytes) -> bytes:
         child_1, child_2 = sorted([left_child, right_child])
         return self._hashing_function(child_1 + child_2)
diff --git a/tests/data_io/__init__.py b/tests/data_io/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/data_io/example.py b/tests/data_io/example.py
new file mode 100644
index 0000000..dd6f9ac
--- /dev/null
+++ b/tests/data_io/example.py
@@ -0,0 +1,24 @@
+leafs = [
+    (123, True),
+    (71254, False),
+    (42386, True),
+]
+
+json_dump = """
+{
+    "format": "standard-v1",
+    "tree": [
+        "0x045578d5e654c3f10fecce520b17f5b8073630bf780db850721b1a8a5df3b839",
+        "0x9385965ba65029c2cbdb3782f35b755a76788ef236602b98118ef535cca36e5c",
+        "0xee7349a2ed7003d5da5bcfdc43253f4a7b757d72ab681f3178469b73087ac3e7",
+        "0x8b1d412fe317e16a1c98414d61c95cd2e44ba6b16907e90c2b0035cf6261d01d",
+        "0x39ee1707f21ec11bac8c0d42538c09f71c0fe3ceb0a2f6be012c2769a714af3d"
+    ],
+    "values": [
+        {"value": [123, true], "treeIndex": 2},
+        {"value": [71254, false], "treeIndex": 4},
+        {"value": [42386, true], "treeIndex": 3}
+    ],
+    "leafEncoding": ["int256", "bool"]
+}
+"""
diff --git a/tests/data_io/test_json_exporter.py b/tests/data_io/test_json_exporter.py
new file mode 100644
index 0000000..18a4aae
--- /dev/null
+++ b/tests/data_io/test_json_exporter.py
@@ -0,0 +1,13 @@
+from merkle_zeppelin import MerkleTree, MerkleTreeJSONExporter
+
+from ..utils import remove_whitespaces
+from .example import json_dump, leafs
+
+
+def test_json_export() -> None:
+    # when
+    tree = MerkleTree(leafs, ["int256", "bool"])
+    exported_tree = tree.export_tree(MerkleTreeJSONExporter)
+
+    # then
+    assert remove_whitespaces(exported_tree) == remove_whitespaces(json_dump)
diff --git a/tests/data_io/test_json_importer.py b/tests/data_io/test_json_importer.py
new file mode 100644
index 0000000..c36cebc
--- /dev/null
+++ b/tests/data_io/test_json_importer.py
@@ -0,0 +1,18 @@
+import json
+
+from merkle_zeppelin import MerkleTree, MerkleTreeJSONImporter
+
+from ..utils import remove_0x
+from .example import json_dump, leafs
+
+
+def test_json_import() -> None:
+    # given
+    tree_input = json.loads(json_dump)
+
+    # when
+    tree = MerkleTree.import_tree(json_dump, MerkleTreeJSONImporter, validate=True)
+
+    # then
+    assert tree.root.hex() == remove_0x(tree_input["tree"][0])
+    assert len(tree._nodes) == (len(leafs) * 2) - 1
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000..c1b8d96
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,9 @@
+import string
+
+
+def remove_0x(data: str) -> str:
+    return data[2:] if data.startswith("0x") else data
+
+
+def remove_whitespaces(data: str) -> str:
+    return data.translate(str.maketrans("", "", string.whitespace))