Skip to content

Commit

Permalink
feat: enumerable bytes set library
Browse files Browse the repository at this point in the history
  • Loading branch information
RayXpub committed Sep 30, 2024
1 parent 19690b0 commit 5fbc875
Show file tree
Hide file tree
Showing 2 changed files with 270 additions and 0 deletions.
135 changes: 135 additions & 0 deletions contracts/src/v0.8/shared/enumerable/EnumerableBytesSet.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// SPDX-License-Identifier: MIT
/* solhint-disable-next-line chainlink-solidity/prefix-internal-functions-with-underscore */
pragma solidity ^0.8.0;

/// Library for managing sets of bytes. Reuses OpenZeppelin's EnumerableSet library logic but for bytes.
library EnumerableBytesSet {
struct BytesSet {
bytes[] _values;
mapping(bytes value => uint256) _positions;
}

/// @dev Adds a value to a set. O(1).
/// @param set The set to add the value to.
/// @param value The value to add.
/// @return True if the value was added to the set, false if the value was already in the set.
function add(BytesSet storage set, bytes memory value) internal returns (bool) {
return _add(set, value);
}

function _add(BytesSet storage set, bytes memory value) private returns (bool) {
if (!_contains(set, value)) {
set._values.push(value);
// The value is stored at length-1, but we add 1 to all indexes
// and use 0 as a sentinel value
set._positions[value] = set._values.length;
return true;
} else {
return false;
}
}

/// @dev Removes a value from a set. O(1).
/// @param set The set to remove the value from.
/// @param value The value to remove.
/// @return True if the value was removed from the set, false if the value was not in the set.
function remove(BytesSet storage set, bytes memory value) internal returns (bool) {
return _remove(set, value);
}

function _remove(BytesSet storage set, bytes memory value) private returns (bool) {
// We cache the value's position to prevent multiple reads from the same storage slot
uint256 position = set._positions[value];

if (position != 0) {
// Equivalent to contains(set, value)
// To delete an element from the _values array in O(1), we swap the element to delete with the last one in
// the array, and then remove the last element (sometimes called as 'swap and pop').
// This modifies the order of the array, as noted in {at}.

uint256 valueIndex = position - 1;
uint256 lastIndex = set._values.length - 1;

if (valueIndex != lastIndex) {
bytes memory lastValue = set._values[lastIndex];

// Move the lastValue to the index where the value to delete is
set._values[valueIndex] = lastValue;
// Update the tracked position of the lastValue (that was just moved)
set._positions[lastValue] = position;
}

// Delete the slot where the moved value was stored
set._values.pop();

// Delete the tracked position for the deleted slot
delete set._positions[value];

return true;
} else {
return false;
}
}

/// @dev Checks if a value is in a set. O(1).
/// @param set The set to check the value in.
/// @param value The value to check.
/// @return True if the value is in the set, false otherwise.
function contains(BytesSet storage set, bytes memory value) internal view returns (bool) {
return _contains(set, value);
}

function _contains(BytesSet storage set, bytes memory value) private view returns (bool) {
return set._positions[value] != 0;
}

/// @dev Returns the number of values in the set. O(1).
/// @param set The set to count values in.
/// @return The number of values in the set.
function length(BytesSet storage set) internal view returns (uint256) {
return _length(set);
}

function _length(BytesSet storage set) private view returns (uint256) {
return set._values.length;
}

/// @dev Returns the value stored at position `index` in the set. O(1).
/// Note that there are no guarantees on the ordering of values inside the array, and it may change when more values
/// are added or removed.
/// @dev precondition - `index` must be strictly less than {length}.
/// @param set The set to get the value from.
/// @param index The index to get the value at.
/// @return The value stored at the specified index.
function at(BytesSet storage set, uint256 index) internal view returns (bytes memory) {
return _at(set, index);
}

function _at(BytesSet storage set, uint256 index) private view returns (bytes memory) {
return set._values[index];
}

/// @dev Returns the entire set in an array
///
/// WARNING: This operation will copy the entire storage to memory, which can be quite expensive. This is designed to
/// mostly be used by view accessors that are queried without any gas fees. Developers should keep in mind that this
/// function has an unbounded cost, and using it as part of a state-changing function may render the function
/// uncallable if the set grows to a point where copying to memory consumes too much gas to fit in a block.
/// @param set The set to get the values from.
///
/// @return An array containing all the values in the set.
function values(BytesSet storage set) internal view returns (bytes[] memory) {
bytes[] memory store = _values(set);
bytes[] memory result;

assembly ("memory-safe") {
result := store
}

return result;
}

function _values(BytesSet storage set) private view returns (bytes[] memory) {
return set._values;
}
}
135 changes: 135 additions & 0 deletions contracts/src/v0.8/shared/test/enumerable/EnumerableBytesSet.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// SPDX-License-Identifier: MIT
pragma solidity 0.8.19;

import {EnumerableBytesSet} from "../../enumerable/EnumerableBytesSet.sol";

import {Test} from "../../../vendor/forge-std/src/Test.sol";

contract EnumerableBytesSetTest is Test {
function _assertBytesArrayEq(bytes[] memory a, bytes[] memory b) internal {
assertEq(a.length, b.length);
for (uint256 i = 0; i < a.length; i++) {
assertEq(a[i], b[i]);
}
}
}

contract EnumerableBytesSetTest_Add is EnumerableBytesSetTest {
using EnumerableBytesSet for EnumerableBytesSet.BytesSet;

EnumerableBytesSet.BytesSet private s_set;

function test_add_SingleValue() public {
bytes memory value = "value";
bytes[] memory expected = new bytes[](1);
expected[0] = value;

assertFalse(s_set.contains(value));
assertTrue(s_set.add(value));
assertEq(s_set.length(), 1);
assertEq(s_set.at(0), value);
assertTrue(s_set.contains(value));
_assertBytesArrayEq(s_set.values(), expected);
}

function test_add_AlreadyExistingValue() public {
bytes memory value = "value";
bytes[] memory expected = new bytes[](1);
expected[0] = value;

assertTrue(s_set.add(value));
assertFalse(s_set.add(value));
assertEq(s_set.length(), 1);
assertEq(s_set.at(0), value);
assertTrue(s_set.contains(value));
_assertBytesArrayEq(s_set.values(), expected);
}

function test_add_MultipleUniqueValues() public {
bytes memory value1 = "value1";
bytes memory value2 = "value2";
bytes[] memory expected = new bytes[](2);
expected[0] = value1;
expected[1] = value2;

assertTrue(s_set.add(value1));
assertTrue(s_set.add(value2));
assertEq(s_set.length(), 2);
assertTrue(s_set.contains(value1));
assertTrue(s_set.contains(value2));
assertEq(s_set.at(0), value1);
assertEq(s_set.at(1), value2);
_assertBytesArrayEq(s_set.values(), expected);
}

function testFuzz_add(bytes[2] memory values) public {
bytes[] memory expected = new bytes[](values.length);

for (uint256 i = 0; i < values.length; ++i) {
// Ensure uniqueness
expected[i] = bytes.concat(values[i], abi.encodePacked(i));
s_set.add(expected[i]);

assertEq(s_set.at(i), expected[i]);
assertTrue(s_set.contains(expected[i]));
}

assertEq(s_set.length(), values.length);
_assertBytesArrayEq(s_set.values(), expected);
}
}

contract EnumerableBytesSet_Remove is EnumerableBytesSetTest {
using EnumerableBytesSet for EnumerableBytesSet.BytesSet;

EnumerableBytesSet.BytesSet private s_set;

function setUp() public {
s_set.add("value1");
s_set.add("value2");
}

function test_remove_SingleExistingValue() public {
bytes memory value = "value1";
bytes[] memory expected = new bytes[](1);
expected[0] = "value2";

assertTrue(s_set.remove(value));
assertEq(s_set.length(), 1);
assertFalse(s_set.contains(value));
assertEq(s_set.at(0), "value2");
_assertBytesArrayEq(s_set.values(), expected);
}

function test_remove_MultipleExistingValues() public {
bytes memory value1 = "value1";
bytes memory value2 = "value2";
bytes[] memory expected = new bytes[](0);

vm.expectRevert();
assertEq(s_set.at(0), "");
vm.expectRevert();
assertEq(s_set.at(1), "");

assertTrue(s_set.remove(value1));
assertTrue(s_set.remove(value2));
assertEq(s_set.length(), 0);
assertFalse(s_set.contains(value1));
assertFalse(s_set.contains(value2));
_assertBytesArrayEq(s_set.values(), expected);
}

function test_remove_SingleNonExistingValue() public {
bytes memory value = "value3";
bytes[] memory expected = new bytes[](2);
expected[0] = "value1";
expected[1] = "value2";

assertFalse(s_set.remove(value));
assertEq(s_set.length(), 2);
assertFalse(s_set.contains(value));
assertEq(s_set.at(0), "value1");
assertEq(s_set.at(1), "value2");
_assertBytesArrayEq(s_set.values(), expected);
}
}

0 comments on commit 5fbc875

Please sign in to comment.