diff --git a/contracts/src/SuccinctGateway.sol b/contracts/src/SuccinctGateway.sol index 78cc97f4..54ca4495 100644 --- a/contracts/src/SuccinctGateway.sol +++ b/contracts/src/SuccinctGateway.sol @@ -314,7 +314,7 @@ contract SuccinctGateway is ISuccinctGateway, FunctionRegistry, TimelockedUpgrad if (msg.sender != verifierOwners[_functionId]) { revert NotFunctionOwner(msg.sender, verifierOwners[_functionId]); } - allowedProvers[_functionId][_prover] = false; + delete allowedProvers[_functionId][_prover]; emit ProverUpdated(_functionId, _prover, false); } @@ -328,7 +328,7 @@ contract SuccinctGateway is ISuccinctGateway, FunctionRegistry, TimelockedUpgrad /// @notice Remove a default prover. /// @param _prover The address of the prover to remove. function removeDefaultProver(address _prover) external onlyGuardian { - allowedProvers[bytes32(0)][_prover] = false; + delete allowedProvers[bytes32(0)][_prover]; emit ProverUpdated(bytes32(0), _prover, false); } @@ -339,6 +339,16 @@ contract SuccinctGateway is ISuccinctGateway, FunctionRegistry, TimelockedUpgrad feeVault = _feeVault; } + /// @notice Recovers stuck ETH from the contract. + /// @param _to The address to send the ETH to. + /// @param _amount The wei amount of ETH to send. + function recover(address _to, uint256 _amount) external onlyGuardian { + (bool success,) = _to.call{value: _amount}(""); + if (!success) { + revert RecoverFailed(); + } + } + /// @dev Computes a unique identifier for a request. /// @param _functionId The function identifier. /// @param _inputHash The hash of the function input. diff --git a/contracts/src/interfaces/ISuccinctGateway.sol b/contracts/src/interfaces/ISuccinctGateway.sol index deb850f8..29d65d7c 100644 --- a/contracts/src/interfaces/ISuccinctGateway.sol +++ b/contracts/src/interfaces/ISuccinctGateway.sol @@ -44,6 +44,7 @@ interface ISuccinctGatewayErrors { error InvalidProof(address verifier, bytes32 inputHash, bytes32 outputHash, bytes proof); error ReentrantFulfill(); error OnlyProver(bytes32 functionId, address sender); + error RecoverFailed(); } interface ISuccinctGateway is ISuccinctGatewayEvents, ISuccinctGatewayErrors { diff --git a/contracts/test/SuccinctGateway.t.sol b/contracts/test/SuccinctGateway.t.sol index 6c84b533..9fa1e2ce 100644 --- a/contracts/test/SuccinctGateway.t.sol +++ b/contracts/test/SuccinctGateway.t.sol @@ -162,7 +162,7 @@ contract RequestTest is SuccinctGatewayTest { bytes32 functionId = TestConsumer(consumer).FUNCTION_ID(); address callbackAddress = consumer; bytes4 callbackSelector = TestConsumer.handleCallback.selector; - uint32 callbackGasLimit = TestConsumer(consumer).CALLBACK_GAS_LIMIT(); + uint32 callbackGasLimit = DEFAULT_GAS_LIMIT; uint256 fee = DEFAULT_FEE; bytes memory context = abi.encode(nonce); bytes memory output = OUTPUT; @@ -181,7 +181,7 @@ contract RequestTest is SuccinctGatewayTest { fee ); vm.prank(sender); - TestConsumer(consumer).requestCallback{value: fee}(); + TestConsumer(consumer).requestCallback{value: fee}(DEFAULT_GAS_LIMIT); assertEq(prevNonce + 1, SuccinctGateway(gateway).nonce()); assertEq(TestConsumer(consumer).handledRequests(0), false); @@ -219,7 +219,7 @@ contract RequestTest is SuccinctGatewayTest { bytes32 functionId = TestConsumer(consumer).FUNCTION_ID(); address callbackAddress = consumer; bytes4 callbackSelector = TestConsumer.handleCallback.selector; - uint32 callbackGasLimit = TestConsumer(consumer).CALLBACK_GAS_LIMIT(); + uint32 callbackGasLimit = DEFAULT_GAS_LIMIT; uint256 fee = DEFAULT_FEE; bytes memory context = abi.encode(nonce); bytes memory output = OUTPUT; @@ -238,7 +238,7 @@ contract RequestTest is SuccinctGatewayTest { fee ); vm.prank(sender); - TestConsumer(consumer).requestCallback{value: fee}(); + TestConsumer(consumer).requestCallback{value: fee}(callbackGasLimit); assertEq(prevNonce + 1, SuccinctGateway(gateway).nonce()); assertEq(TestConsumer(consumer).handledRequests(0), false); @@ -366,6 +366,12 @@ contract RequestTest is SuccinctGatewayTest { ); assertEq(TestConsumer(consumer).handledRequests(0), true); + + // Recover ETH + vm.prank(guardian); + SuccinctGateway(gateway).recover(guardian, fee); + + assertEq(guardian.balance, fee); } function test_Callback_WhenCallbackGasLimitTooLow() public { @@ -859,7 +865,7 @@ contract FunctionRegistryTest is function test_RevertDeployAndRegisterFunction_WhenAlreadyRegistered() public { // Deploy verifier and register function - (bytes32 functionId1,) = IFunctionRegistry(gateway).deployAndRegisterFunction( + IFunctionRegistry(gateway).deployAndRegisterFunction( owner, type(TestFunctionVerifier1).creationCode, "test-verifier1" ); @@ -1139,3 +1145,26 @@ contract SetFeeVaultTest is SuccinctGatewayTest { SuccinctGateway(gateway).setFeeVault(newFeeVault); } } + +contract RecoverTest is SuccinctGatewayTest { + function test_Recover() public { + uint256 fee = DEFAULT_FEE; + vm.deal(gateway, fee); + + // Recover ETH + vm.prank(guardian); + SuccinctGateway(gateway).recover(guardian, fee); + + assertEq(guardian.balance, fee); + } + + function test_RevertRecover_WhenNotGuardian() public { + uint256 fee = DEFAULT_FEE; + vm.deal(gateway, fee); + + // Recover ETH + vm.expectRevert(abi.encodeWithSignature("OnlyGuardian(address)", sender)); + vm.prank(sender); + SuccinctGateway(gateway).recover(guardian, fee); + } +}