diff --git a/src/core/Permit2Payment.sol b/src/core/Permit2Payment.sol index cc403a497..973c62a96 100644 --- a/src/core/Permit2Payment.sol +++ b/src/core/Permit2Payment.sol @@ -87,6 +87,30 @@ library TransientStorage { } } + function setCallback(function (bytes calldata) internal returns (bytes memory) callback) internal { + assembly ("memory-safe") { + let operator := tload(_OPERATOR_SLOT) + if shr(0xa0, operator) { + mstore(0x00, 0x77f94425) // selector for `ReentrantCallback(address)` + mstore(0x00, and(0xffffffffffffffffffffffffffffffffffffffff, operator)) + revert(0x1c, 0x24) + } + tstore(_OPERATOR_SLOT, or(shl(0xa0, callback), operator)) + } + } + + error CallbackNotSpent(uint256 callbackInt); + + function checkSpentCallback() internal view { + uint256 callbackInt; + assembly ("memory-safe") { + callbackInt := shr(0xa0, tload(_OPERATOR_SLOT)) + } + if (callbackInt != 0) { + revert CallbackNotSpent(callbackInt); + } + } + function getAndClearCallback() internal returns (function (bytes calldata) internal returns (bytes memory) callback) @@ -94,12 +118,7 @@ library TransientStorage { assembly ("memory-safe") { let operator := tload(_OPERATOR_SLOT) callback := shr(0xa0, operator) - operator := and(0xffffffffffffffffffffffffffffffffffffffff, operator) - if iszero(eq(operator, caller())) { - mstore(0x00, 0xe758b8d5) // selector for ConfusedDeputy() - revert(0x1c, 0x04) - } - tstore(_OPERATOR_SLOT, operator) + tstore(_OPERATOR_SLOT, and(0xffffffffffffffffffffffffffffffffffffffff, operator)) } } @@ -179,6 +198,27 @@ abstract contract Permit2PaymentBase is AllowanceHolderContext, SettlerAbstract return _setOperatorAndCall(payable(target), 0, data, callback); } + function _setCallbackAndCall( + address payable target, + uint256 value, + bytes memory data, + function (bytes calldata) internal returns (bytes memory) callback + ) internal override returns (bytes memory) { + TransientStorage.setCallback(callback); + (bool success, bytes memory returndata) = target.call{value: value}(data); + success.maybeRevert(returndata); + TransientStorage.checkSpentCallback(); + return returndata; + } + + function _setCallbackAndCall( + address target, + bytes memory data, + function (bytes calldata) internal returns (bytes memory) callback + ) internal override returns (bytes memory) { + return _setCallbackAndCall(payable(target), 0, data, callback); + } + modifier metaTx(address msgSender, bytes32 witness) override { if (_isForwarded()) { revert ConfusedDeputy(); diff --git a/src/core/Permit2PaymentAbstract.sol b/src/core/Permit2PaymentAbstract.sol index a1092072f..3dbdbc597 100644 --- a/src/core/Permit2PaymentAbstract.sol +++ b/src/core/Permit2PaymentAbstract.sol @@ -63,5 +63,18 @@ abstract contract Permit2PaymentAbstract is AbstractContext { function (bytes calldata) internal returns (bytes memory) callback ) internal virtual returns (bytes memory); + function _setCallbackAndCall( + address payable target, + uint256 value, + bytes memory data, + function (bytes calldata) internal returns (bytes memory) callback + ) internal virtual returns (bytes memory); + + function _setCallbackAndCall( + address target, + bytes memory data, + function (bytes calldata) internal returns (bytes memory) callback + ) internal virtual returns (bytes memory); + modifier metaTx(address msgSender, bytes32 witness) virtual; } diff --git a/src/core/UniswapV3.sol b/src/core/UniswapV3.sol index 8c5ae1e4e..fbf165816 100644 --- a/src/core/UniswapV3.sol +++ b/src/core/UniswapV3.sol @@ -170,13 +170,23 @@ abstract contract UniswapV3 is SettlerAbstract { int256 amount0; int256 amount1; if (payer == address(this)) { - (amount0, amount1) = pool.swap( - // Intermediate tokens go to this contract. - isPathMultiHop ? address(this) : recipient, - zeroForOne, - int256(sellAmount), - zeroForOne ? MIN_PRICE_SQRT_RATIO + 1 : MAX_PRICE_SQRT_RATIO - 1, - swapCallbackData + (amount0, amount1) = abi.decode( + _setCallbackAndCall( + address(pool), + abi.encodeCall( + pool.swap, + ( + // Intermediate tokens go to this contract. + isPathMultiHop ? address(this) : recipient, + zeroForOne, + int256(sellAmount), + zeroForOne ? MIN_PRICE_SQRT_RATIO + 1 : MAX_PRICE_SQRT_RATIO - 1, + swapCallbackData + ) + ), + _uniV3Callback + ), + (int256, int256) ); } else { (amount0, amount1) = abi.decode( @@ -332,9 +342,11 @@ abstract contract UniswapV3 is SettlerAbstract { function _uniV3Callback(bytes calldata data) private returns (bytes memory) { require(data.length >= 0x84 && bytes4(data) == _UNIV3_CALLBACK_SELECTOR); - int256 amount0Delta = int256(uint256(bytes32(data[0x04:]))); - int256 amount1Delta = int256(uint256(bytes32(data[0x24:]))); + int256 amount0Delta; + int256 amount1Delta; assembly ("memory-safe") { + amount0Delta := calldataload(0x04) + amount1Delta := calldataload(0x24) data.offset := add(0x04, calldataload(add(0x44, data.offset))) data.length := calldataload(data.offset) data.offset := add(0x20, data.offset)