diff --git a/src/modules/SafeEmailRecoveryModule.sol b/src/modules/SafeEmailRecoveryModule.sol index e428873..c4dfe9d 100644 --- a/src/modules/SafeEmailRecoveryModule.sol +++ b/src/modules/SafeEmailRecoveryModule.sol @@ -13,6 +13,7 @@ contract SafeEmailRecoveryModule is EmailRecoveryManager { event RecoveryExecuted(address indexed account); + error InvalidAccount(address account); error InvalidSelector(bytes4 selector); error RecoveryFailed(address account); @@ -37,20 +38,23 @@ contract SafeEmailRecoveryModule is EmailRecoveryManager { } /** - * @notice Executes recovery on a Safe account. Must be called by the trusted recovery manager + * @notice Executes recovery on a Safe account. Called from the recovery manager * @param account The account to execute recovery for - * @param recoveryData The recovery calldata that should be executed on the Safe - * being recovered + * @param recoveryData The recovery data that should be executed on the Safe + * being recovered. recoveryData = abi.encode(safeAccount, recoveryFunctionCalldata) */ function recover(address account, bytes calldata recoveryData) internal override { - (, bytes memory recoveryCalldata) = abi.decode(recoveryData, (address, bytes)); - // FIXME: What if you use this module with a different subject handler? It could chose - // not to encode the account/validator along with the calldata + (address encodedAccount, bytes memory recoveryCalldata) = + abi.decode(recoveryData, (address, bytes)); + + if (encodedAccount == address(0) || encodedAccount != account) { + revert InvalidAccount(encodedAccount); + } + bytes4 calldataSelector; assembly { calldataSelector := mload(add(recoveryCalldata, 32)) } - if (calldataSelector != selector) { revert InvalidSelector(calldataSelector); } diff --git a/test/integration/OwnableValidatorRecovery/EmailRecoveryModule/EmailRecoveryModule.t.sol b/test/integration/OwnableValidatorRecovery/EmailRecoveryModule/EmailRecoveryModule.t.sol index 16705d1..c63804b 100644 --- a/test/integration/OwnableValidatorRecovery/EmailRecoveryModule/EmailRecoveryModule.t.sol +++ b/test/integration/OwnableValidatorRecovery/EmailRecoveryModule/EmailRecoveryModule.t.sol @@ -67,6 +67,7 @@ contract OwnableValidatorRecovery_EmailRecoveryModule_Integration_Test is assertEq(recoveryRequest.executeAfter, 0); assertEq(recoveryRequest.executeBefore, 0); assertEq(recoveryRequest.currentWeight, 1); + assertEq(recoveryRequest.recoveryDataHash, recoveryDataHash1); // handle recovery request for guardian 2 uint256 executeAfter = block.timestamp + delay; @@ -76,6 +77,7 @@ contract OwnableValidatorRecovery_EmailRecoveryModule_Integration_Test is assertEq(recoveryRequest.executeAfter, executeAfter); assertEq(recoveryRequest.executeBefore, executeBefore); assertEq(recoveryRequest.currentWeight, 3); + assertEq(recoveryRequest.recoveryDataHash, recoveryDataHash1); // Time travel so that the recovery delay has passed vm.warp(block.timestamp + delay); @@ -89,6 +91,7 @@ contract OwnableValidatorRecovery_EmailRecoveryModule_Integration_Test is assertEq(recoveryRequest.executeAfter, 0); assertEq(recoveryRequest.executeBefore, 0); assertEq(recoveryRequest.currentWeight, 0); + assertEq(recoveryRequest.recoveryDataHash, bytes32(0)); assertEq(updatedOwner, newOwner1); } diff --git a/test/integration/OwnableValidatorRecovery/UniversalEmailRecoveryModule/UniversalEmailRecoveryModule.t.sol b/test/integration/OwnableValidatorRecovery/UniversalEmailRecoveryModule/UniversalEmailRecoveryModule.t.sol index f5eb74c..6fa9462 100644 --- a/test/integration/OwnableValidatorRecovery/UniversalEmailRecoveryModule/UniversalEmailRecoveryModule.t.sol +++ b/test/integration/OwnableValidatorRecovery/UniversalEmailRecoveryModule/UniversalEmailRecoveryModule.t.sol @@ -66,6 +66,7 @@ contract OwnableValidatorRecovery_UniversalEmailRecoveryModule_Integration_Test assertEq(recoveryRequest.executeAfter, 0); assertEq(recoveryRequest.executeBefore, 0); assertEq(recoveryRequest.currentWeight, 1); + assertEq(recoveryRequest.recoveryDataHash, recoveryDataHash1); // handle recovery request for guardian 2 uint256 executeAfter = block.timestamp + delay; @@ -75,6 +76,7 @@ contract OwnableValidatorRecovery_UniversalEmailRecoveryModule_Integration_Test assertEq(recoveryRequest.executeAfter, executeAfter); assertEq(recoveryRequest.executeBefore, executeBefore); assertEq(recoveryRequest.currentWeight, 3); + assertEq(recoveryRequest.recoveryDataHash, recoveryDataHash1); // Time travel so that the recovery delay has passed vm.warp(block.timestamp + delay); @@ -88,6 +90,7 @@ contract OwnableValidatorRecovery_UniversalEmailRecoveryModule_Integration_Test assertEq(recoveryRequest.executeAfter, 0); assertEq(recoveryRequest.executeBefore, 0); assertEq(recoveryRequest.currentWeight, 0); + assertEq(recoveryRequest.recoveryDataHash, bytes32(0)); assertEq(updatedOwner, newOwner1); } diff --git a/test/integration/SafeRecovery/SafeRecoveryNativeModule.t.sol b/test/integration/SafeRecovery/SafeRecoveryNativeModule.t.sol index df30286..6eae5a2 100644 --- a/test/integration/SafeRecovery/SafeRecoveryNativeModule.t.sol +++ b/test/integration/SafeRecovery/SafeRecoveryNativeModule.t.sol @@ -32,7 +32,7 @@ contract SafeRecoveryNativeModule_Integration_Test is SafeNativeIntegrationBase "swapOwner(address,address,address)", address(1), owner, newOwner ); bytes memory recoveryData = abi.encode(safeAddress, recoveryCalldata); - bytes32 calldataHash = keccak256(recoveryData); + bytes32 recoveryDataHash = keccak256(recoveryData); bytes[] memory subjectParamsForRecovery = new bytes[](4); subjectParamsForRecovery[0] = abi.encode(safeAddress); @@ -64,7 +64,10 @@ contract SafeRecoveryNativeModule_Integration_Test is SafeNativeIntegrationBase emailRecoveryModule.handleRecovery(emailAuthMsg, templateIdx); IEmailRecoveryManager.RecoveryRequest memory recoveryRequest = emailRecoveryModule.getRecoveryRequest(safeAddress); + assertEq(recoveryRequest.executeAfter, 0); + assertEq(recoveryRequest.executeBefore, 0); assertEq(recoveryRequest.currentWeight, 1); + assertEq(recoveryRequest.recoveryDataHash, recoveryDataHash); // handle recovery request for guardian 2 uint256 executeAfter = block.timestamp + delay; @@ -75,6 +78,7 @@ contract SafeRecoveryNativeModule_Integration_Test is SafeNativeIntegrationBase assertEq(recoveryRequest.executeAfter, executeAfter); assertEq(recoveryRequest.executeBefore, executeBefore); assertEq(recoveryRequest.currentWeight, 3); + assertEq(recoveryRequest.recoveryDataHash, recoveryDataHash); vm.warp(block.timestamp + delay); @@ -85,6 +89,7 @@ contract SafeRecoveryNativeModule_Integration_Test is SafeNativeIntegrationBase assertEq(recoveryRequest.executeAfter, 0); assertEq(recoveryRequest.executeBefore, 0); assertEq(recoveryRequest.currentWeight, 0); + assertEq(recoveryRequest.recoveryDataHash, bytes32(0)); vm.prank(safeAddress); bool isOwner = Safe(payable(safeAddress)).isOwner(newOwner);