diff --git a/contracts/libs/data-structures/IncrementalMerkleTree.sol b/contracts/libs/data-structures/IncrementalMerkleTree.sol index 493b6ae8..816c2f7c 100644 --- a/contracts/libs/data-structures/IncrementalMerkleTree.sol +++ b/contracts/libs/data-structures/IncrementalMerkleTree.sol @@ -48,6 +48,14 @@ library IncrementalMerkleTree { IMT _tree; } + /** + * @notice The Uint256 Incremental Merkle Tree constructor, creates a tree instance with the + * given size, O(1) complex + */ + function newUint(uint256 size_) internal pure returns (UintIMT memory imt) { + imt._tree = _new(size_); + } + /** * @notice The function to add a new element to the uint256 tree. * Complexity is O(log(n)), where n is the number of elements in the tree. @@ -123,6 +131,14 @@ library IncrementalMerkleTree { IMT _tree; } + /** + * @notice The Bytes32 Incremental Merkle Tree constructor, creates a tree instance with the + * given size, O(1) complex + */ + function newBytes32(uint256 size_) internal pure returns (Bytes32IMT memory imt) { + imt._tree = _new(size_); + } + /** * @notice The function to add a new element to the bytes32 tree. * Complexity is O(log(n)), where n is the number of elements in the tree. @@ -188,6 +204,14 @@ library IncrementalMerkleTree { IMT _tree; } + /** + * @notice The Address Incremental Merkle Tree constructor, creates a tree instance with the + * given size, O(1) complex + */ + function newAddress(uint256 size_) internal pure returns (AddressIMT memory imt) { + imt._tree = _new(size_); + } + /** * @notice The function to add a new element to the address tree. * Complexity is O(log(n)), where n is the number of elements in the tree. @@ -257,6 +281,10 @@ library IncrementalMerkleTree { function(bytes32, bytes32) view returns (bytes32) hash2Fn; } + function _new(uint256 size_) private pure returns (IMT memory imt) { + imt.branches = new bytes32[](size_); + } + function _setHashers( IMT storage tree, function(bytes32) view returns (bytes32) hash1Fn_, diff --git a/contracts/mock/libs/data-structures/IncrementalMerkleTreeMock.sol b/contracts/mock/libs/data-structures/IncrementalMerkleTreeMock.sol index 717f3610..4c9ab82d 100644 --- a/contracts/mock/libs/data-structures/IncrementalMerkleTreeMock.sol +++ b/contracts/mock/libs/data-structures/IncrementalMerkleTreeMock.sol @@ -18,6 +18,18 @@ contract IncrementalMerkleTreeMock { IncrementalMerkleTree.Bytes32IMT internal _bytes32Tree; IncrementalMerkleTree.AddressIMT internal _addressTree; + function newUintTree(uint256 treeHeight_) external { + _uintTree = IncrementalMerkleTree.newUint(treeHeight_); + } + + function newBytes32Tree(uint256 treeHeight_) external { + _bytes32Tree = IncrementalMerkleTree.newBytes32(treeHeight_); + } + + function newAddressTree(uint256 treeHeight_) external { + _addressTree = IncrementalMerkleTree.newAddress(treeHeight_); + } + function addUint(uint256 element_) external { _uintTree.add(element_); } diff --git a/test/libs/data-structures/IncrementalMerkleTree.test.ts b/test/libs/data-structures/IncrementalMerkleTree.test.ts index 44d30bda..38962f05 100644 --- a/test/libs/data-structures/IncrementalMerkleTree.test.ts +++ b/test/libs/data-structures/IncrementalMerkleTree.test.ts @@ -54,6 +54,22 @@ describe("IncrementalMerkleTree", () => { } describe("Uint IMT", () => { + it("should build a Merkle Tree with predefined size", async () => { + await merkleTree.newUintTree(10); + + const element = 2341; + + await merkleTree.addUint(element); + + const elementHash = getUintElementHash(element); + + localMerkleTree = buildSparseMerkleTree([elementHash], 10); + + expect(await merkleTree.getUintRoot()).to.equal(getRoot(localMerkleTree)); + expect(await merkleTree.getUintTreeLength()).to.equal(1n); + expect(await merkleTree.getUintTreeHeight()).to.equal(10n); + }); + it("should add element to tree", async () => { const element = 1234; @@ -113,6 +129,22 @@ describe("IncrementalMerkleTree", () => { }); describe("Bytes32 IMT", () => { + it("should build a Merkle Tree with predefined size", async () => { + await merkleTree.newBytes32Tree(10); + + const element = ethers.encodeBytes32String(`0x1234`); + + await merkleTree.addBytes32(element); + + const elementHash = getBytes32ElementHash(element); + + localMerkleTree = buildSparseMerkleTree([elementHash], 10); + + expect(await merkleTree.getBytes32Root()).to.equal(getRoot(localMerkleTree)); + expect(await merkleTree.getBytes32TreeLength()).to.equal(1n); + expect(await merkleTree.getBytes32TreeHeight()).to.equal(10n); + }); + it("should add element to tree", async () => { const element = ethers.encodeBytes32String(`0x1234`); @@ -176,6 +208,22 @@ describe("IncrementalMerkleTree", () => { }); describe("Address IMT", () => { + it("should build a Merkle Tree with predefined size", async () => { + await merkleTree.newAddressTree(10); + + const element = USER1.address; + + await merkleTree.addAddress(element); + + const elementHash = getAddressElementHash(element); + + localMerkleTree = buildSparseMerkleTree([elementHash], 10); + + expect(await merkleTree.getAddressRoot()).to.equal(getRoot(localMerkleTree)); + expect(await merkleTree.getAddressTreeLength()).to.equal(1n); + expect(await merkleTree.getAddressTreeHeight()).to.equal(10n); + }); + it("should add element to tree", async () => { const element = USER1.address;