From 5fbc875d85842f061cec1ee65cdc7e3c286cf022 Mon Sep 17 00:00:00 2001 From: Ryan <80392855+RayXpub@users.noreply.github.com> Date: Mon, 30 Sep 2024 22:23:42 +0400 Subject: [PATCH] feat: enumerable bytes set library --- .../shared/enumerable/EnumerableBytesSet.sol | 135 ++++++++++++++++++ .../test/enumerable/EnumerableBytesSet.t.sol | 135 ++++++++++++++++++ 2 files changed, 270 insertions(+) create mode 100644 contracts/src/v0.8/shared/enumerable/EnumerableBytesSet.sol create mode 100644 contracts/src/v0.8/shared/test/enumerable/EnumerableBytesSet.t.sol diff --git a/contracts/src/v0.8/shared/enumerable/EnumerableBytesSet.sol b/contracts/src/v0.8/shared/enumerable/EnumerableBytesSet.sol new file mode 100644 index 00000000000..36f1102695b --- /dev/null +++ b/contracts/src/v0.8/shared/enumerable/EnumerableBytesSet.sol @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +/* solhint-disable-next-line chainlink-solidity/prefix-internal-functions-with-underscore */ +pragma solidity ^0.8.0; + +/// Library for managing sets of bytes. Reuses OpenZeppelin's EnumerableSet library logic but for bytes. +library EnumerableBytesSet { + struct BytesSet { + bytes[] _values; + mapping(bytes value => uint256) _positions; + } + + /// @dev Adds a value to a set. O(1). + /// @param set The set to add the value to. + /// @param value The value to add. + /// @return True if the value was added to the set, false if the value was already in the set. + function add(BytesSet storage set, bytes memory value) internal returns (bool) { + return _add(set, value); + } + + function _add(BytesSet storage set, bytes memory value) private returns (bool) { + if (!_contains(set, value)) { + set._values.push(value); + // The value is stored at length-1, but we add 1 to all indexes + // and use 0 as a sentinel value + set._positions[value] = set._values.length; + return true; + } else { + return false; + } + } + + /// @dev Removes a value from a set. O(1). + /// @param set The set to remove the value from. + /// @param value The value to remove. + /// @return True if the value was removed from the set, false if the value was not in the set. + function remove(BytesSet storage set, bytes memory value) internal returns (bool) { + return _remove(set, value); + } + + function _remove(BytesSet storage set, bytes memory value) private returns (bool) { + // We cache the value's position to prevent multiple reads from the same storage slot + uint256 position = set._positions[value]; + + if (position != 0) { + // Equivalent to contains(set, value) + // To delete an element from the _values array in O(1), we swap the element to delete with the last one in + // the array, and then remove the last element (sometimes called as 'swap and pop'). + // This modifies the order of the array, as noted in {at}. + + uint256 valueIndex = position - 1; + uint256 lastIndex = set._values.length - 1; + + if (valueIndex != lastIndex) { + bytes memory lastValue = set._values[lastIndex]; + + // Move the lastValue to the index where the value to delete is + set._values[valueIndex] = lastValue; + // Update the tracked position of the lastValue (that was just moved) + set._positions[lastValue] = position; + } + + // Delete the slot where the moved value was stored + set._values.pop(); + + // Delete the tracked position for the deleted slot + delete set._positions[value]; + + return true; + } else { + return false; + } + } + + /// @dev Checks if a value is in a set. O(1). + /// @param set The set to check the value in. + /// @param value The value to check. + /// @return True if the value is in the set, false otherwise. + function contains(BytesSet storage set, bytes memory value) internal view returns (bool) { + return _contains(set, value); + } + + function _contains(BytesSet storage set, bytes memory value) private view returns (bool) { + return set._positions[value] != 0; + } + + /// @dev Returns the number of values in the set. O(1). + /// @param set The set to count values in. + /// @return The number of values in the set. + function length(BytesSet storage set) internal view returns (uint256) { + return _length(set); + } + + function _length(BytesSet storage set) private view returns (uint256) { + return set._values.length; + } + + /// @dev Returns the value stored at position `index` in the set. O(1). + /// Note that there are no guarantees on the ordering of values inside the array, and it may change when more values + /// are added or removed. + /// @dev precondition - `index` must be strictly less than {length}. + /// @param set The set to get the value from. + /// @param index The index to get the value at. + /// @return The value stored at the specified index. + function at(BytesSet storage set, uint256 index) internal view returns (bytes memory) { + return _at(set, index); + } + + function _at(BytesSet storage set, uint256 index) private view returns (bytes memory) { + return set._values[index]; + } + + /// @dev Returns the entire set in an array + /// + /// WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed to + /// mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that this + /// function has an unbounded cost, and using it as part of a state-changing function may render the function + /// uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block. + /// @param set The set to get the values from. + /// + /// @return An array containing all the values in the set. + function values(BytesSet storage set) internal view returns (bytes[] memory) { + bytes[] memory store = _values(set); + bytes[] memory result; + + assembly ("memory-safe") { + result := store + } + + return result; + } + + function _values(BytesSet storage set) private view returns (bytes[] memory) { + return set._values; + } +} diff --git a/contracts/src/v0.8/shared/test/enumerable/EnumerableBytesSet.t.sol b/contracts/src/v0.8/shared/test/enumerable/EnumerableBytesSet.t.sol new file mode 100644 index 00000000000..e486695b547 --- /dev/null +++ b/contracts/src/v0.8/shared/test/enumerable/EnumerableBytesSet.t.sol @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +pragma solidity 0.8.19; + +import {EnumerableBytesSet} from "../../enumerable/EnumerableBytesSet.sol"; + +import {Test} from "../../../vendor/forge-std/src/Test.sol"; + +contract EnumerableBytesSetTest is Test { + function _assertBytesArrayEq(bytes[] memory a, bytes[] memory b) internal { + assertEq(a.length, b.length); + for (uint256 i = 0; i < a.length; i++) { + assertEq(a[i], b[i]); + } + } +} + +contract EnumerableBytesSetTest_Add is EnumerableBytesSetTest { + using EnumerableBytesSet for EnumerableBytesSet.BytesSet; + + EnumerableBytesSet.BytesSet private s_set; + + function test_add_SingleValue() public { + bytes memory value = "value"; + bytes[] memory expected = new bytes[](1); + expected[0] = value; + + assertFalse(s_set.contains(value)); + assertTrue(s_set.add(value)); + assertEq(s_set.length(), 1); + assertEq(s_set.at(0), value); + assertTrue(s_set.contains(value)); + _assertBytesArrayEq(s_set.values(), expected); + } + + function test_add_AlreadyExistingValue() public { + bytes memory value = "value"; + bytes[] memory expected = new bytes[](1); + expected[0] = value; + + assertTrue(s_set.add(value)); + assertFalse(s_set.add(value)); + assertEq(s_set.length(), 1); + assertEq(s_set.at(0), value); + assertTrue(s_set.contains(value)); + _assertBytesArrayEq(s_set.values(), expected); + } + + function test_add_MultipleUniqueValues() public { + bytes memory value1 = "value1"; + bytes memory value2 = "value2"; + bytes[] memory expected = new bytes[](2); + expected[0] = value1; + expected[1] = value2; + + assertTrue(s_set.add(value1)); + assertTrue(s_set.add(value2)); + assertEq(s_set.length(), 2); + assertTrue(s_set.contains(value1)); + assertTrue(s_set.contains(value2)); + assertEq(s_set.at(0), value1); + assertEq(s_set.at(1), value2); + _assertBytesArrayEq(s_set.values(), expected); + } + + function testFuzz_add(bytes[2] memory values) public { + bytes[] memory expected = new bytes[](values.length); + + for (uint256 i = 0; i < values.length; ++i) { + // Ensure uniqueness + expected[i] = bytes.concat(values[i], abi.encodePacked(i)); + s_set.add(expected[i]); + + assertEq(s_set.at(i), expected[i]); + assertTrue(s_set.contains(expected[i])); + } + + assertEq(s_set.length(), values.length); + _assertBytesArrayEq(s_set.values(), expected); + } +} + +contract EnumerableBytesSet_Remove is EnumerableBytesSetTest { + using EnumerableBytesSet for EnumerableBytesSet.BytesSet; + + EnumerableBytesSet.BytesSet private s_set; + + function setUp() public { + s_set.add("value1"); + s_set.add("value2"); + } + + function test_remove_SingleExistingValue() public { + bytes memory value = "value1"; + bytes[] memory expected = new bytes[](1); + expected[0] = "value2"; + + assertTrue(s_set.remove(value)); + assertEq(s_set.length(), 1); + assertFalse(s_set.contains(value)); + assertEq(s_set.at(0), "value2"); + _assertBytesArrayEq(s_set.values(), expected); + } + + function test_remove_MultipleExistingValues() public { + bytes memory value1 = "value1"; + bytes memory value2 = "value2"; + bytes[] memory expected = new bytes[](0); + + vm.expectRevert(); + assertEq(s_set.at(0), ""); + vm.expectRevert(); + assertEq(s_set.at(1), ""); + + assertTrue(s_set.remove(value1)); + assertTrue(s_set.remove(value2)); + assertEq(s_set.length(), 0); + assertFalse(s_set.contains(value1)); + assertFalse(s_set.contains(value2)); + _assertBytesArrayEq(s_set.values(), expected); + } + + function test_remove_SingleNonExistingValue() public { + bytes memory value = "value3"; + bytes[] memory expected = new bytes[](2); + expected[0] = "value1"; + expected[1] = "value2"; + + assertFalse(s_set.remove(value)); + assertEq(s_set.length(), 2); + assertFalse(s_set.contains(value)); + assertEq(s_set.at(0), "value1"); + assertEq(s_set.at(1), "value2"); + _assertBytesArrayEq(s_set.values(), expected); + } +}