diff --git a/src/OmniCounter.sol b/src/OmniCounter.sol new file mode 100644 index 0000000..16d59a2 --- /dev/null +++ b/src/OmniCounter.sol @@ -0,0 +1,277 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.19; + +import {OptionsBuilder} from "@layerzerolabs/lz-evm-oapp-v2/contracts/oapp/libs/OptionsBuilder.sol"; +import { + MessagingParams, + MessagingReceipt +} from "@layerzerolabs/lz-evm-protocol-v2/contracts/interfaces/ILayerZeroEndpointV2.sol"; +import {UUPSUpgradeable} from "@openzeppelin/contracts-upgradeable/proxy/utils/UUPSUpgradeable.sol"; +import {OAppUpgradeable, MessagingFee, Origin} from "@zodomo/oapp-upgradeable/OAppUpgradeable.sol"; +import {ILayerZeroComposer} from "@layerzerolabs/lz-evm-protocol-v2/contracts/interfaces/ILayerZeroComposer.sol"; + +library MsgCodec { + uint8 internal constant VANILLA_TYPE = 1; + uint8 internal constant COMPOSED_TYPE = 2; + uint8 internal constant ABA_TYPE = 3; + uint8 internal constant COMPOSED_ABA_TYPE = 4; + + uint8 internal constant MSG_TYPE_OFFSET = 0; + uint8 internal constant SRC_EID_OFFSET = 1; + uint8 internal constant VALUE_OFFSET = 5; + + function encode(uint8 _type, uint32 _srcEid) internal pure returns (bytes memory) { + return abi.encodePacked(_type, _srcEid); + } + + function encode(uint8 _type, uint32 _srcEid, uint256 _value) internal pure returns (bytes memory) { + return abi.encodePacked(_type, _srcEid, _value); + } + + function msgType(bytes calldata _message) internal pure returns (uint8) { + return uint8(bytes1(_message[MSG_TYPE_OFFSET:SRC_EID_OFFSET])); + } + + function srcEid(bytes calldata _message) internal pure returns (uint32) { + return uint32(bytes4(_message[SRC_EID_OFFSET:VALUE_OFFSET])); + } + + function value(bytes calldata _message) internal pure returns (uint256) { + return uint256(bytes32(_message[VALUE_OFFSET:])); + } +} + +contract OmniCounter is ILayerZeroComposer, OAppUpgradeable, UUPSUpgradeable { + using MsgCodec for bytes; + using OptionsBuilder for bytes; + + uint256 public count; + uint256 public composedCount; + + address public admin; + uint32 public eid; + + mapping(uint32 srcEid => mapping(bytes32 sender => uint64 nonce)) private maxReceivedNonce; + bool private orderedNonce; + + // for global assertions + mapping(uint32 srcEid => uint256 count) public inboundCount; + mapping(uint32 dstEid => uint256 count) public outboundCount; + + constructor() { + _disableInitializers(); + } + + /** + * @dev Initialize the OApp with the provided endpoint and owner. + * @param _endpoint The address of the LOCAL LayerZero endpoint. + * @param _owner The address of the owner of the OApp. + */ + function initialize(address _endpoint, address _owner) public initializer { + _initializeOApp(_endpoint, _owner); + } + + modifier onlyAdmin() { + require(msg.sender == admin, "only admin"); + _; + } + + // ------------------------------- + // Only Admin + function setAdmin(address _admin) external onlyAdmin { + admin = _admin; + } + + function withdraw(address payable _to, uint256 _amount) external onlyAdmin { + (bool success,) = _to.call{value: _amount}(""); + require(success, "OmniCounter: withdraw failed"); + } + + // ------------------------------- + // Send + function increment(uint32 _eid, uint8 _type, bytes calldata _options) external payable { + // bytes memory options = combineOptions(_eid, _type, _options); + _lzSend(_eid, MsgCodec.encode(_type, eid), _options, MessagingFee(msg.value, 0), payable(msg.sender)); + _incrementOutbound(_eid); + } + + // this is a broken function to skip incrementing outbound count + // so that preCrime will fail + function brokenIncrement(uint32 _eid, uint8 _type, bytes calldata _options) external payable onlyAdmin { + // bytes memory options = combineOptions(_eid, _type, _options); + _lzSend(_eid, MsgCodec.encode(_type, eid), _options, MessagingFee(msg.value, 0), payable(msg.sender)); + } + + function batchIncrement(uint32[] calldata _eids, uint8[] calldata _types, bytes[] calldata _options) + external + payable + { + require(_eids.length == _options.length && _eids.length == _types.length, "OmniCounter: length mismatch"); + + MessagingReceipt memory receipt; + uint256 providedFee = msg.value; + for (uint256 i = 0; i < _eids.length; i++) { + address refundAddress = i == _eids.length - 1 ? msg.sender : address(this); + uint32 dstEid = _eids[i]; + uint8 msgType = _types[i]; + // bytes memory options = combineOptions(dstEid, msgType, _options[i]); + receipt = _lzSend( + dstEid, MsgCodec.encode(msgType, eid), _options[i], MessagingFee(providedFee, 0), payable(refundAddress) + ); + _incrementOutbound(dstEid); + providedFee -= receipt.fee.nativeFee; + } + } + + // ------------------------------- + // View + function quote(uint32 _eid, uint8 _type, bytes calldata _options) + public + view + returns (uint256 nativeFee, uint256 lzTokenFee) + { + // bytes memory options = combineOptions(_eid, _type, _options); + MessagingFee memory fee = _quote(_eid, MsgCodec.encode(_type, eid), _options, false); + return (fee.nativeFee, fee.lzTokenFee); + } + + // ------------------------------- + function _lzReceive( + Origin calldata _origin, + bytes32 _guid, + bytes calldata _message, + address, /*_executor*/ + bytes calldata /*_extraData*/ + ) internal override { + _acceptNonce(_origin.srcEid, _origin.sender, _origin.nonce); + uint8 messageType = _message.msgType(); + + if (messageType == MsgCodec.VANILLA_TYPE) { + count++; + + //////////////////////////////// IMPORTANT ////////////////////////////////// + /// if you request for msg.value in the options, you should also encode it + /// into your message and check the value received at destination (example below). + /// if not, the executor could potentially provide less msg.value than you requested + /// leading to unintended behavior. Another option is to assert the executor to be + /// one that you trust. + ///////////////////////////////////////////////////////////////////////////// + require(msg.value >= _message.value(), "OmniCounter: insufficient value"); + + _incrementInbound(_origin.srcEid); + } else if (messageType == MsgCodec.COMPOSED_TYPE || messageType == MsgCodec.COMPOSED_ABA_TYPE) { + count++; + _incrementInbound(_origin.srcEid); + endpoint.sendCompose(address(this), _guid, 0, _message); + } else if (messageType == MsgCodec.ABA_TYPE) { + count++; + _incrementInbound(_origin.srcEid); + + // send back to the sender + _incrementOutbound(_origin.srcEid); + bytes memory options = OptionsBuilder.newOptions().addExecutorLzReceiveOption(200000, 10); + _lzSend( + _origin.srcEid, + MsgCodec.encode(MsgCodec.VANILLA_TYPE, eid, 10), + options, + MessagingFee(msg.value, 0), + payable(address(this)) + ); + } else { + revert("invalid message type"); + } + } + + function _incrementInbound(uint32 _srcEid) internal { + inboundCount[_srcEid]++; + } + + function _incrementOutbound(uint32 _dstEid) internal { + outboundCount[_dstEid]++; + } + + function lzCompose(address _oApp, bytes32, /*_guid*/ bytes calldata _message, address, bytes calldata) + external + payable + override + { + require(_oApp == address(this), "!oApp"); + require(msg.sender == address(endpoint), "!endpoint"); + + uint8 msgType = _message.msgType(); + if (msgType == MsgCodec.COMPOSED_TYPE) { + composedCount += 1; + } else if (msgType == MsgCodec.COMPOSED_ABA_TYPE) { + composedCount += 1; + + uint32 srcEid = _message.srcEid(); + _incrementOutbound(srcEid); + bytes memory options = OptionsBuilder.newOptions().addExecutorLzReceiveOption(200000, 0); + _lzSend( + srcEid, + MsgCodec.encode(MsgCodec.VANILLA_TYPE, eid), + options, + MessagingFee(msg.value, 0), + payable(address(this)) + ); + } else { + revert("invalid message type"); + } + } + + // ------------------------------- + // Ordered OApp + // this demonstrates how to build an app that requires execution nonce ordering + // normally an app should decide ordered or not on contract construction + // this is just a demo + function setOrderedNonce(bool _orderedNonce) external onlyOwner { + orderedNonce = _orderedNonce; + } + + function _acceptNonce(uint32 _srcEid, bytes32 _sender, uint64 _nonce) internal virtual { + uint64 currentNonce = maxReceivedNonce[_srcEid][_sender]; + if (orderedNonce) { + require(_nonce == currentNonce + 1, "OApp: invalid nonce"); + } + // update the max nonce anyway. once the ordered mode is turned on, missing early nonces will be rejected + if (_nonce > currentNonce) { + maxReceivedNonce[_srcEid][_sender] = _nonce; + } + } + + function nextNonce(uint32 _srcEid, bytes32 _sender) public view virtual override returns (uint64) { + if (orderedNonce) { + return maxReceivedNonce[_srcEid][_sender] + 1; + } else { + return 0; // path nonce starts from 1. if 0 it means that there is no specific nonce enforcement + } + } + + // TODO should override oApp version with added ordered nonce increment + // a governance function to skip nonce + function skipInboundNonce(uint32 _srcEid, bytes32 _sender, uint64 _nonce) public virtual onlyOwner { + endpoint.skip(address(this), _srcEid, _sender, _nonce); + if (orderedNonce) { + maxReceivedNonce[_srcEid][_sender]++; + } + } + + // @dev Batch send requires overriding this function from OAppSender because the msg.value contains multiple fees + function _payNative(uint256 _nativeFee) internal virtual override returns (uint256 nativeFee) { + if (msg.value < _nativeFee) revert NotEnoughNative(msg.value); + return _nativeFee; + } + + // be able to receive ether + receive() external payable virtual {} + + fallback() external payable {} + + /* ========== UUPS ========== */ + //solhint-disable-next-line no-empty-blocks + function _authorizeUpgrade(address) internal override onlyOwner {} + + function getImplementation() external view returns (address) { + return _getImplementation(); + } +} diff --git a/test/Counter.t.sol b/test/Counter.t.sol index 66112d5..69f3bd7 100644 --- a/test/Counter.t.sol +++ b/test/Counter.t.sol @@ -23,7 +23,7 @@ contract CounterTest is ProxyTestHelper { setUpEndpoints(2, LibraryType.UltraLightNode); - (address[] memory uas,) = setupOAppsProxies(1, 2); + (address[] memory uas,) = setupOAppsProxies(type(Counter).creationCode, 1, 2); aCounter = Counter(payable(uas[0])); bCounter = Counter(payable(uas[1])); } @@ -59,4 +59,15 @@ contract CounterTest is ProxyTestHelper { assertEq(aCounter.count(), counterBefore + 1, "increment assertion failure"); } + + // required for test helper to know how to initialize the OApp + function _deployOAppProxy(address _endpoint, address _owner, address implementationAddress) + internal + override + returns (address proxyAddress) + { + UUPSProxy proxy = + new UUPSProxy(implementationAddress, abi.encodeWithSelector(Counter.initialize.selector, _endpoint, _owner)); + proxyAddress = address(proxy); + } } diff --git a/test/CounterUpgradeability.t.sol b/test/CounterUpgradeability.t.sol index cbb80f5..9259b23 100644 --- a/test/CounterUpgradeability.t.sol +++ b/test/CounterUpgradeability.t.sol @@ -28,7 +28,7 @@ contract CounterUpgradeabilityTest is ProxyTestHelper { setUpEndpoints(2, LibraryType.UltraLightNode); - (address[] memory uas, address implementationAddress) = setupOAppsProxies(1, 2); + (address[] memory uas, address implementationAddress) = setupOAppsProxies(type(Counter).creationCode, 1, 2); counterImplementation = Counter(implementationAddress); @@ -96,4 +96,15 @@ contract CounterUpgradeabilityTest is ProxyTestHelper { bCounter.increment{value: nativeFee}(aEid, options); verifyPackets(aEid, addressToBytes32(address(counter))); } + + // required for test helper to know how to initialize the OApp + function _deployOAppProxy(address _endpoint, address _owner, address implementationAddress) + internal + override + returns (address proxyAddress) + { + UUPSProxy proxy = + new UUPSProxy(implementationAddress, abi.encodeWithSelector(Counter.initialize.selector, _endpoint, _owner)); + proxyAddress = address(proxy); + } } diff --git a/test/OmniCounter.t.sol b/test/OmniCounter.t.sol new file mode 100644 index 0000000..ebe5a4b --- /dev/null +++ b/test/OmniCounter.t.sol @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.19; + +import {Test, console2} from "forge-std/Test.sol"; + +import {OptionsBuilder} from "@layerzerolabs/lz-evm-oapp-v2/contracts/oapp/libs/OptionsBuilder.sol"; + +import {OmniCounter, MsgCodec} from "../src/OmniCounter.sol"; +import {UUPSProxy} from "../src/UUPSProxy.sol"; +import {ProxyTestHelper} from "./utils/ProxyTestHelper.sol"; + +contract CounterTest is ProxyTestHelper { + using OptionsBuilder for bytes; + + uint32 aEid = 1; + uint32 bEid = 2; + + OmniCounter public aCounter; + OmniCounter public bCounter; + + function setUp() public virtual override { + super.setUp(); + + setUpEndpoints(2, LibraryType.UltraLightNode); + + (address[] memory uas,) = setupOAppsProxies(type(OmniCounter).creationCode, 1, 2); + aCounter = OmniCounter(payable(uas[0])); + bCounter = OmniCounter(payable(uas[1])); + } + + // classic message passing A -> B + function test_increment() public { + uint256 counterBefore = bCounter.count(); + + bytes memory options = OptionsBuilder.newOptions().addExecutorLzReceiveOption(200000, 0); + (uint256 nativeFee,) = aCounter.quote(bEid, MsgCodec.VANILLA_TYPE, options); + aCounter.increment{value: nativeFee}(bEid, MsgCodec.VANILLA_TYPE, options); + + assertEq(bCounter.count(), counterBefore, "shouldn't be increased until packet is verified"); + + // verify packet to bCounter manually + verifyPackets(bEid, addressToBytes32(address(bCounter))); + + assertEq(bCounter.count(), counterBefore + 1, "increment assertion failure"); + } + + function test_batchIncrement() public { + uint256 counterBefore = bCounter.count(); + + uint256 batchSize = 5; + uint32[] memory eids = new uint32[](batchSize); + uint8[] memory types = new uint8[](batchSize); + bytes[] memory options = new bytes[](batchSize); + bytes memory option = OptionsBuilder.newOptions().addExecutorLzReceiveOption(200000, 0); + uint256 fee; + for (uint256 i = 0; i < batchSize; i++) { + eids[i] = bEid; + types[i] = MsgCodec.VANILLA_TYPE; + options[i] = option; + (uint256 nativeFee,) = aCounter.quote(eids[i], types[i], options[i]); + fee += nativeFee; + } + + vm.expectRevert(); // Errors.InvalidAmount + aCounter.batchIncrement{value: fee - 1}(eids, types, options); + + aCounter.batchIncrement{value: fee}(eids, types, options); + verifyPackets(bEid, addressToBytes32(address(bCounter))); + + assertEq(bCounter.count(), counterBefore + batchSize, "batchIncrement assertion failure"); + } + + // classic message passing A -> B1 -> B2 + function test_lzCompose_increment() public { + uint256 countBefore = bCounter.count(); + uint256 composedCountBefore = bCounter.composedCount(); + + bytes memory options = + OptionsBuilder.newOptions().addExecutorLzReceiveOption(200000, 0).addExecutorLzComposeOption(0, 200000, 0); + (uint256 nativeFee,) = aCounter.quote(bEid, MsgCodec.COMPOSED_TYPE, options); + aCounter.increment{value: nativeFee}(bEid, MsgCodec.COMPOSED_TYPE, options); + + verifyPackets(bEid, addressToBytes32(address(bCounter)), 0, address(bCounter)); + + assertEq(bCounter.count(), countBefore + 1, "increment B1 assertion failure"); + assertEq(bCounter.composedCount(), composedCountBefore + 1, "increment B2 assertion failure"); + } + + // A -> B -> A + function test_ABA_increment() public { + uint256 countABefore = aCounter.count(); + uint256 countBBefore = bCounter.count(); + + bytes memory options = OptionsBuilder.newOptions().addExecutorLzReceiveOption(10000000, 10000000); + (uint256 nativeFee,) = aCounter.quote(bEid, MsgCodec.ABA_TYPE, options); + aCounter.increment{value: nativeFee}(bEid, MsgCodec.ABA_TYPE, options); + + verifyPackets(bEid, addressToBytes32(address(bCounter))); + assertEq(aCounter.count(), countABefore, "increment A assertion failure"); + assertEq(bCounter.count(), countBBefore + 1, "increment B assertion failure"); + + verifyPackets(aEid, addressToBytes32(address(aCounter))); + assertEq(aCounter.count(), countABefore + 1, "increment A assertion failure"); + } + + // required for test helper to know how to initialize the OApp + function _deployOAppProxy(address _endpoint, address _owner, address implementationAddress) + internal + override + returns (address proxyAddress) + { + UUPSProxy proxy = new UUPSProxy( + implementationAddress, abi.encodeWithSelector(OmniCounter.initialize.selector, _endpoint, _owner) + ); + proxyAddress = address(proxy); + } +} diff --git a/test/utils/ProxyTestHelper.sol b/test/utils/ProxyTestHelper.sol index 5ca8afa..2a47265 100644 --- a/test/utils/ProxyTestHelper.sol +++ b/test/utils/ProxyTestHelper.sol @@ -6,10 +6,9 @@ import {Test, console2} from "forge-std/Test.sol"; import {OptionsBuilder} from "@layerzerolabs/lz-evm-oapp-v2/contracts/oapp/libs/OptionsBuilder.sol"; import {TestHelper} from "@layerzerolabs/lz-evm-oapp-v2/test/TestHelper.sol"; -import {Counter} from "../../src/Counter.sol"; import {UUPSProxy} from "../../src/UUPSProxy.sol"; -contract ProxyTestHelper is TestHelper { +abstract contract ProxyTestHelper is TestHelper { using OptionsBuilder for bytes; function setUp() public virtual override {} @@ -17,12 +16,16 @@ contract ProxyTestHelper is TestHelper { /** * @dev setup UAs, only if the UA has `endpoint` address as the unique parameter */ - function setupOAppsProxies(uint8 _startEid, uint8 _oappNum) + function setupOAppsProxies(bytes memory _oappCreationCode, uint8 _startEid, uint8 _oappNum) public returns (address[] memory oapps, address implementationAddress) { - Counter counter = new Counter(); - implementationAddress = address(counter); + implementationAddress = address(0); + + assembly { + implementationAddress := create(0, add(_oappCreationCode, 0x20), mload(_oappCreationCode)) + if iszero(extcodesize(implementationAddress)) { revert(0, 0) } + } oapps = new address[](_oappNum); for (uint8 eid = _startEid; eid < _startEid + _oappNum; eid++) { @@ -35,10 +38,6 @@ contract ProxyTestHelper is TestHelper { function _deployOAppProxy(address _endpoint, address _owner, address implementationAddress) internal - returns (address proxyAddress) - { - UUPSProxy proxy = - new UUPSProxy(implementationAddress, abi.encodeWithSelector(Counter.initialize.selector, _endpoint, _owner)); - proxyAddress = address(proxy); - } + virtual + returns (address proxyAddress); }