Skip to content

Commit

Permalink
feature: add importer and exporter (#2)
Browse files Browse the repository at this point in the history
Co-authored-by: mokrzesa <[email protected]>
Co-authored-by: Sergii Denysiuk <[email protected]>
  • Loading branch information
3 people authored Nov 24, 2023
1 parent f7c1cb3 commit 973ac60
Show file tree
Hide file tree
Showing 13 changed files with 247 additions and 26 deletions.
16 changes: 15 additions & 1 deletion merkle_zeppelin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
from .data_io.base_exporter import MerkleTreeExporter
from .data_io.base_importer import MerkleTreeImporter
from .data_io.dto import LeafValueDTO, MerkleTreeDTO
from .data_io.json_exporter import MerkleTreeJSONExporter
from .data_io.json_importer import MerkleTreeJSONImporter
from .trees.binary import BinaryTree
from .trees.merkle import MerkleTree

# Public API of the package: the two tree types plus the import/export
# layer (DTOs, the abstract importer/exporter and their JSON variants).
__all__ = [
    "BinaryTree",
    "MerkleTree",
    "LeafValueDTO",
    "MerkleTreeDTO",
    "MerkleTreeImporter",
    "MerkleTreeExporter",
    "MerkleTreeJSONImporter",
    "MerkleTreeJSONExporter",
]
Empty file.
11 changes: 11 additions & 0 deletions merkle_zeppelin/data_io/base_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
from typing import Any

from .dto import MerkleTreeDTO


class MerkleTreeExporter(ABC):
    """Abstract interface for turning a ``MerkleTreeDTO`` into output data.

    Implementations provide :meth:`export_tree` as a static method, so an
    exporter is used through its class and never has to be instantiated.
    """

    @staticmethod
    @abstractmethod
    def export_tree(data: MerkleTreeDTO) -> Any:
        """Convert ``data`` into the exporter's own representation."""
11 changes: 11 additions & 0 deletions merkle_zeppelin/data_io/base_importer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
from typing import Any

from .dto import MerkleTreeDTO


class MerkleTreeImporter(ABC):
    """Abstract interface for parsing external data into a ``MerkleTreeDTO``.

    Implementations provide :meth:`import_tree` as a static method, so an
    importer is used through its class and never has to be instantiated.
    """

    @staticmethod
    @abstractmethod
    def import_tree(data: Any) -> MerkleTreeDTO:
        """Parse ``data`` and return the tree description it contains."""
22 changes: 22 additions & 0 deletions merkle_zeppelin/data_io/dto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any

from pydantic import BaseModel, ConfigDict, Field
from pydantic.alias_generators import to_camel

Leaf = tuple[Any, ...]


class LeafValueDTO(BaseModel):
    # One raw leaf together with the position of its hash in the tree.
    # Aliases are produced by pydantic's ``to_camel`` generator, so the
    # serialized key for ``tree_index`` is ``treeIndex``.
    model_config = ConfigDict(alias_generator=to_camel)

    # The raw (un-hashed) leaf: a tuple of arbitrary values.
    value: Leaf
    # Index of this leaf's node within ``MerkleTreeDTO.tree``.
    tree_index: int


class MerkleTreeDTO(BaseModel):
    # Serializable description of a whole Merkle tree. Aliases are produced
    # by pydantic's ``to_camel`` generator (``leaf_encoding`` -> ``leafEncoding``).
    model_config = ConfigDict(alias_generator=to_camel)

    # Format tag; defaults to "standard-v1" and is declared frozen.
    format: str = Field("standard-v1", frozen=True)
    # Node hashes as hex strings, in flat node order.
    # NOTE(review): the import test treats element 0 as the root - confirm.
    tree: list[str]
    # Raw leaves with the index of their node in ``tree``.
    values: list[LeafValueDTO]
    # Type names used when encoding a leaf, e.g. ["int256", "bool"].
    leaf_encoding: list[str]
8 changes: 8 additions & 0 deletions merkle_zeppelin/data_io/json_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .base_exporter import MerkleTreeExporter
from .dto import MerkleTreeDTO


class MerkleTreeJSONExporter(MerkleTreeExporter):
    """Exporter that renders a ``MerkleTreeDTO`` as a JSON string."""

    @staticmethod
    def export_tree(data: MerkleTreeDTO) -> str:
        """Return ``data`` dumped to JSON, with keys written as field aliases."""
        serialized = data.model_dump_json(by_alias=True)
        return serialized
8 changes: 8 additions & 0 deletions merkle_zeppelin/data_io/json_importer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .base_importer import MerkleTreeImporter
from .dto import MerkleTreeDTO


class MerkleTreeJSONImporter(MerkleTreeImporter):
    """Importer that reads a ``MerkleTreeDTO`` from a JSON string."""

    @staticmethod
    def import_tree(data: str) -> MerkleTreeDTO:
        """Validate the JSON document ``data`` and return it as a DTO."""
        parsed = MerkleTreeDTO.model_validate_json(data)
        return parsed
133 changes: 108 additions & 25 deletions merkle_zeppelin/trees/merkle.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from typing import Any, Callable, Union
from __future__ import annotations

from operator import itemgetter
from typing import Any, Callable, Optional, Type, Union

from Crypto.Hash import keccak
from eth_abi import encode

from ..data_io.base_exporter import MerkleTreeExporter
from ..data_io.base_importer import MerkleTreeImporter
from ..data_io.dto import Leaf, LeafValueDTO, MerkleTreeDTO
from .binary import BinaryTree
from .exceptions import MerkleTreeValidationFailed, ValueNotFoundInTree

Leaf = tuple[Any, ...]


def keccak256(v: bytes) -> bytes:
    """Return the 256-bit Keccak digest of ``v``."""
    hasher = keccak.new(digest_bits=256)
    hasher.update(v)
    return hasher.digest()
Expand All @@ -16,32 +20,92 @@ def keccak256(v: bytes) -> bytes:
class MerkleTree(BinaryTree[bytes]):
def __init__(
self,
leaves: list[Leaf],
raw_elements: list[Leaf],
types: list[str],
hashing_function: Callable[[bytes], bytes] = None,
hashing_function: Optional[Callable[[bytes], bytes]] = None,
) -> None:
self._hashing_function = hashing_function or keccak256
self._types = types

hash_leaf_pairs = self._get_hash_leaf_pairs(leaves)
hash_leaf_pairs.sort(reverse=True)
sorted_index_hash_pairs = self._get_sorted_index_hash_pairs(raw_elements)
sorted_hashes = [hash_ for index, hash_ in sorted_index_hash_pairs]

ordered_hashed_leaves = [leaf_hash for leaf_hash, _ in hash_leaf_pairs]
ordered_leaves = [leaf for _, leaf in hash_leaf_pairs]
super().__init__(sorted_hashes)

super().__init__(ordered_hashed_leaves)
self._raw_to_leaves_index_mapping = self._get_raw_to_leaves_index_mapping(
raw_elements, sorted_index_hash_pairs
)

@staticmethod
def get_hash_from_string(hash_: str) -> bytes:
if hash_.startswith("0x"):
hash_ = hash_[2:]

return bytes.fromhex(hash_)

@classmethod
def import_tree(
    cls,
    data: Any,
    importer: Type[MerkleTreeImporter],
    validate: bool = True,
    hashing_function: Optional[Callable[[bytes], bytes]] = None,
) -> MerkleTree:
    """Build a tree from external ``data`` using the given importer class.

    ``importer.import_tree`` turns ``data`` into a DTO, which is then handed
    to :meth:`import_tree_from_dto` together with ``validate`` and
    ``hashing_function``.
    """
    return cls.import_tree_from_dto(
        importer.import_tree(data), validate, hashing_function
    )

@classmethod
def import_tree_from_dto(
    cls,
    dto: MerkleTreeDTO,
    validate: bool = True,
    hashing_function: Optional[Callable[[bytes], bytes]] = None,
) -> MerkleTree:
    """Restore a tree from a DTO without running ``__init__``.

    The node hashes and the leaf-to-index mapping are taken from ``dto`` as
    stored; nothing is re-hashed or re-sorted here. When ``validate`` is
    true the restored object is passed through ``validate`` before being
    returned. ``hashing_function`` falls back to ``keccak256``.
    """
    # Bare instance: the attributes below are filled in by hand instead of
    # being derived from raw leaves as the constructor does.
    tree = cls.__new__(cls)
    tree._types = dto.leaf_encoding
    tree._hashing_function = hashing_function or keccak256

    tree._nodes = list(map(cls.get_hash_from_string, dto.tree))

    index_by_leaf = {}
    for entry in dto.values:
        index_by_leaf[entry.value] = entry.tree_index
    tree._raw_to_leaves_index_mapping = index_by_leaf

    if validate:
        cls.validate(tree)

    return tree

@property
def dto(self) -> MerkleTreeDTO:
    """Describe this tree as a ``MerkleTreeDTO``.

    Nodes are written as ``0x``-prefixed hex strings; each raw leaf is paired
    with its index from the leaf-to-index mapping, in the mapping's order.
    """
    hex_nodes = ["0x" + node.hex() for node in self._nodes]

    leaf_records = []
    for value, index in self._raw_to_leaves_index_mapping.items():
        leaf_records.append(LeafValueDTO(value=value, treeIndex=index))

    return MerkleTreeDTO(
        tree=hex_nodes,
        values=leaf_records,
        leafEncoding=self._types,
    )

self._raw_leaves_index = self._get_raw_leaves_index(ordered_leaves)
def export_tree(self, exporter: Type[MerkleTreeExporter]) -> Any:
    """Export this tree by passing its DTO to ``exporter.export_tree``."""
    tree_dto = self.dto
    return exporter.export_tree(tree_dto)

def validate(self, raise_exception: bool = True) -> bool:
for i in range(1, self._inner_nodes_number):
left_node_index = self._get_left_child_index(i)
right_node_index = self._get_right_child_index(i)
calculated_leaves = sorted(
[
self._calculate_leaf_hash(value)
for value in self._raw_to_leaves_index_mapping.keys()
],
reverse=True,
)
if self.leaves != calculated_leaves:
raise MerkleTreeValidationFailed()

for checked, i in enumerate(range(len(self._nodes) - 1, 0, -2)):
calculated_parent = self._calculate_parent_value(
self._nodes[left_node_index], self._nodes[right_node_index]
self._nodes[i], self._nodes[i - 1]
)

if self._nodes[i] != calculated_parent:
if self._nodes[self._inner_nodes_number - checked - 1] != calculated_parent:
if not raise_exception:
return False

Expand All @@ -51,7 +115,7 @@ def validate(self, raise_exception: bool = True) -> bool:

def get_proofs(self, value: Leaf) -> Union[list[bytes], None]:
try:
node_index = self._raw_leaves_index[value]
node_index = self._raw_to_leaves_index_mapping[value]
except ValueError:
raise ValueNotFoundInTree(value)

Expand All @@ -64,19 +128,38 @@ def get_proofs(self, value: Leaf) -> Union[list[bytes], None]:

return result

def _get_sorted_index_hash_pairs(
self, raw_elements: list[Leaf]
) -> list[tuple[int, bytes]]:
hashes = [self._calculate_leaf_hash(el) for el in raw_elements]

return sorted(
enumerate(hashes),
reverse=True,
key=itemgetter(1), # sort by hash (second position)
)

def _get_raw_to_leaves_index_mapping(
self, raw_elements: list[Leaf], ordered_hashed_leaves: list[tuple[int, bytes]]
) -> dict[Leaf, int]:
elements_number = len(raw_elements)

related_indexed_when_sorted = sorted(
range(elements_number), key=lambda i: ordered_hashed_leaves[i]
)

return {
leaf: index_in_sorted_by_hash + self._inner_nodes_number
for leaf, index_in_sorted_by_hash in zip(
raw_elements, related_indexed_when_sorted
)
}

def _calculate_leaf_hash(self, value: Leaf) -> bytes:
    """Return the leaf hash: the encoded value hashed twice.

    ``value`` is encoded with ``encode`` using the tree's ``_types`` and the
    result is passed through the hashing function two times.
    """
    encoded = encode(self._types, value)
    inner_digest = self._hashing_function(encoded)
    return self._hashing_function(inner_digest)

def _get_hash_leaf_pairs(
self, input_leaves: list[Leaf]
) -> list[tuple[bytes, Leaf]]:
return [(self._calculate_leaf_hash(leaf), leaf) for leaf in input_leaves]

def _get_raw_leaves_index(self, hash_leaf_pairs: list[Leaf]) -> dict[Leaf, int]:
return {leaf: self._get_node_index(i) for i, leaf in enumerate(hash_leaf_pairs)}

def _calculate_parent_value(self, left_child: bytes, right_child: bytes) -> bytes:
child_1, child_2 = sorted([left_child, right_child])
return self._hashing_function(child_1 + child_2)
Empty file added tests/data_io/__init__.py
Empty file.
24 changes: 24 additions & 0 deletions tests/data_io/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Raw leaf values of the example tree; each tuple matches the
# ["int256", "bool"] leaf encoding recorded in ``json_dump`` below.
leafs = [
(123, True),
(71254, False),
(42386, True),
]

# Reference JSON for the tree built from ``leafs``. The export test compares
# its output against this text with whitespace removed, and the import test
# loads a tree from it.
json_dump = """
{
"format": "standard-v1",
"tree": [
"0x045578d5e654c3f10fecce520b17f5b8073630bf780db850721b1a8a5df3b839",
"0x9385965ba65029c2cbdb3782f35b755a76788ef236602b98118ef535cca36e5c",
"0xee7349a2ed7003d5da5bcfdc43253f4a7b757d72ab681f3178469b73087ac3e7",
"0x8b1d412fe317e16a1c98414d61c95cd2e44ba6b16907e90c2b0035cf6261d01d",
"0x39ee1707f21ec11bac8c0d42538c09f71c0fe3ceb0a2f6be012c2769a714af3d"
],
"values": [
{"value": [123, true], "treeIndex": 2},
{"value": [71254, false], "treeIndex": 4},
{"value": [42386, true], "treeIndex": 3}
],
"leafEncoding": ["int256", "bool"]
}
"""
13 changes: 13 additions & 0 deletions tests/data_io/test_json_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from merkle_zeppelin import MerkleTree, MerkleTreeJSONExporter

from ..utils import remove_whitespaces
from .example import json_dump, leafs


def test_json_export() -> None:
    """Exporting the example tree reproduces the reference JSON dump."""
    # given
    leaf_encoding = ["int256", "bool"]

    # when
    exported_tree = MerkleTree(leafs, leaf_encoding).export_tree(
        MerkleTreeJSONExporter
    )

    # then: compare with all whitespace removed from both sides
    actual = remove_whitespaces(exported_tree)
    expected = remove_whitespaces(json_dump)
    assert actual == expected
18 changes: 18 additions & 0 deletions tests/data_io/test_json_importer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import json

from merkle_zeppelin import MerkleTree, MerkleTreeJSONImporter

from ..utils import remove_0x
from .example import json_dump, leafs


def test_json_import() -> None:
    """Importing the reference JSON yields the expected root and node count."""
    # given
    expected_root_hex = remove_0x(json.loads(json_dump)["tree"][0])
    expected_node_count = (len(leafs) * 2) - 1

    # when
    tree = MerkleTree.import_tree(json_dump, MerkleTreeJSONImporter, validate=True)

    # then
    assert tree.root.hex() == expected_root_hex
    assert len(tree._nodes) == expected_node_count
9 changes: 9 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import string


def remove_0x(data: str) -> str:
    """Return ``data`` without a single leading ``0x``, if it has one."""
    return data.removeprefix("0x")


def remove_whitespaces(data: str) -> str:
    """Return ``data`` with every ``string.whitespace`` character dropped."""
    return "".join(char for char in data if char not in string.whitespace)

0 comments on commit 973ac60

Please sign in to comment.