Skip to content

Commit

Permalink
feat: non reverting callback modules (#65)
Browse files Browse the repository at this point in the history
* feat: replace external contract call with low level call

* feat: add MockFailCallback contract

a mock callback that always reverts

* test: add callback and multiplecallback module integration test

* feat: add missing `s` in module's variable name

* test: add missing unit test to complete coverage

* test: add unit test for reverting callback

* fix: apply linter suggestion

* test: actually mock the callback to revert

* chore: move common setup to an internal fn

* test: update multiple callbacks tests
  • Loading branch information
xorsal authored Oct 15, 2024
1 parent b88fda6 commit 57f0be5
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 11 deletions.
4 changes: 3 additions & 1 deletion solidity/contracts/modules/finality/CallbackModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ contract CallbackModule is Module, ICallbackModule {
) external override(Module, ICallbackModule) onlyOracle {
RequestParameters memory _params = decodeRequestData(_request.finalityModuleData);

IProphetCallback(_params.target).prophetCallback(_params.data);
// purposely skips the return data, so we don't care if the call succeeds or fails
_params.target.call(abi.encodeCall(IProphetCallback.prophetCallback, (_params.data)));

emit Callback(_response.requestId, _params.target, _params.data);
emit RequestFinalized(_response.requestId, _response, _finalizer);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ contract MultipleCallbacksModule is Module, IMultipleCallbacksModule {
uint256 _length = _params.targets.length;

for (uint256 _i; _i < _length;) {
IProphetCallback(_params.targets[_i]).prophetCallback(_params.data[_i]);
// purposely skips the return data, so we don't care if the call succeeds or fails
_params.targets[_i].call(abi.encodeCall(IProphetCallback.prophetCallback, (_params.data[_i])));
emit Callback(_response.requestId, _params.targets[_i], _params.data[_i]);
unchecked {
++_i;
Expand Down
60 changes: 60 additions & 0 deletions solidity/test/integration/CallbackModule.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.19;

import './IntegrationBase.sol';

contract Integration_CallbackModule is IntegrationBase {
IProphetCallback public callback;

bytes32 internal _requestId;
bytes internal _expectedData = bytes('a-well-formed-calldata');

function setUp() public override {
super.setUp();

callback = new MockCallback();

mockRequest.finalityModuleData =
abi.encode(ICallbackModule.RequestParameters({target: address(callback), data: _expectedData}));
}

function test_finalizeExecutesCallback() public {
_setupRequest();

vm.expectCall(address(callback), abi.encodeCall(IProphetCallback.prophetCallback, (_expectedData)));

// advance time past deadline
vm.warp(block.timestamp + _expectedDeadline + _baseDisputeWindow);
oracle.finalize(mockRequest, mockResponse);
}

function test_callbacksNeverRevert() public {
MockFailCallback _target = new MockFailCallback();
mockRequest.finalityModuleData =
abi.encode(ICallbackModule.RequestParameters({target: address(_target), data: _expectedData}));
_setupRequest();

// expect call to target passing the expected data
vm.expectCall(address(_target), abi.encodeCall(IProphetCallback.prophetCallback, (_expectedData)));

vm.warp(block.timestamp + _expectedDeadline + _baseDisputeWindow);
oracle.finalize(mockRequest, mockResponse);
}

function _setupRequest() internal {
_resetMockIds();

_deposit(_accountingExtension, requester, usdc, _expectedReward);
vm.startPrank(requester);
_accountingExtension.approveModule(address(_requestModule));
_requestId = oracle.createRequest(mockRequest, _ipfsHash);
vm.stopPrank();

_deposit(_accountingExtension, proposer, usdc, _expectedBondSize);
vm.startPrank(proposer);
_accountingExtension.approveModule(address(_responseModule));
mockResponse.response = abi.encode(proposer, _requestId, bytes(''));
oracle.proposeResponse(mockRequest, mockResponse);
vm.stopPrank();
}
}
8 changes: 8 additions & 0 deletions solidity/test/integration/IntegrationBase.sol
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ import {
IRootVerificationModule, RootVerificationModule
} from '../../contracts/modules/dispute/RootVerificationModule.sol';
import {CallbackModule, ICallbackModule} from '../../contracts/modules/finality/CallbackModule.sol';
import {
IMultipleCallbacksModule,
MultipleCallbacksModule
} from '../../contracts/modules/finality/MultipleCallbacksModule.sol';

import {HttpRequestModule, IHttpRequestModule} from '../../contracts/modules/request/HttpRequestModule.sol';
import {
ISparseMerkleTreeRequestModule,
Expand All @@ -37,6 +42,9 @@ import {ArbitratorModule, IArbitratorModule} from '../../contracts/modules/resol
import {BondedResponseModule, IBondedResponseModule} from '../../contracts/modules/response/BondedResponseModule.sol';
import {SparseMerkleTreeL32Verifier} from '../../contracts/periphery/SparseMerkleTreeL32Verifier.sol';

import {IProphetVerifier} from '../../interfaces/IProphetVerifier.sol';
import {MockCallback, MockFailCallback} from '../mocks/MockCallback.sol';

import {IArbitrator} from '../../interfaces/IArbitrator.sol';
import {IProphetCallback} from '../../interfaces/IProphetCallback.sol';
import {ITreeVerifier} from '../../interfaces/ITreeVerifier.sol';
Expand Down
84 changes: 84 additions & 0 deletions solidity/test/integration/MultipleCallbacksModule.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.19;

import './IntegrationBase.sol';

contract Integration_MultipleCallbackModule is IntegrationBase {
uint256 public constant CALLBACKS_AMOUNT = 255;
IProphetCallback public callback;
MultipleCallbacksModule public multipleCallbacksModule;

bytes32 internal _requestId;

function setUp() public override {
super.setUp();

multipleCallbacksModule = new MultipleCallbacksModule(oracle);
mockRequest.finalityModule = address(multipleCallbacksModule);

callback = new MockCallback();
}

function test_finalizeExecutesCallback() public {
(address[] memory _targets, bytes[] memory _datas) = _createCallbacksData(address(callback), CALLBACKS_AMOUNT);
mockRequest.finalityModuleData =
abi.encode(IMultipleCallbacksModule.RequestParameters({targets: _targets, data: _datas}));

_setupRequest();

for (uint256 _i; _i < _datas.length; _i++) {
vm.expectCall(address(_targets[_i]), abi.encodeCall(IProphetCallback.prophetCallback, (_datas[_i])));
}

vm.warp(block.timestamp + _expectedDeadline + _baseDisputeWindow);
oracle.finalize(mockRequest, mockResponse);
}

function test_callbacksNeverRevert() public {
callback = new MockFailCallback();

(address[] memory _targets, bytes[] memory _datas) = _createCallbacksData(address(callback), CALLBACKS_AMOUNT);

mockRequest.finalityModuleData =
abi.encode(IMultipleCallbacksModule.RequestParameters({targets: _targets, data: _datas}));

_setupRequest();

// expect call to every target with the expected data
for (uint256 _i; _i < _datas.length; _i++) {
vm.expectCall(address(_targets[_i]), abi.encodeCall(IProphetCallback.prophetCallback, (_datas[_i])));
}

vm.warp(block.timestamp + _expectedDeadline + _baseDisputeWindow);
oracle.finalize(mockRequest, mockResponse);
}

function _setupRequest() internal {
_resetMockIds();

_deposit(_accountingExtension, requester, usdc, _expectedReward);
vm.startPrank(requester);
_accountingExtension.approveModule(address(_requestModule));
_requestId = oracle.createRequest(mockRequest, _ipfsHash);
vm.stopPrank();

_deposit(_accountingExtension, proposer, usdc, _expectedBondSize);
vm.startPrank(proposer);
_accountingExtension.approveModule(address(_responseModule));
mockResponse.response = abi.encode(proposer, _requestId, bytes(''));
oracle.proposeResponse(mockRequest, mockResponse);
vm.stopPrank();
}

function _createCallbacksData(
address _target,
uint256 _length
) internal returns (address[] memory _targets, bytes[] memory _datas) {
_targets = new address[](_length);
_datas = new bytes[](_length);
for (uint256 _i; _i < _length; _i++) {
_targets[_i] = _target;
_datas[_i] = abi.encode(keccak256(abi.encode(_i)));
}
}
}
9 changes: 9 additions & 0 deletions solidity/test/mocks/MockCallback.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,12 @@ contract MockCallback is IProphetCallback {
_callResponse = abi.encode(true);
}
}

contract MockFailCallback is IProphetCallback {
error MockFailCallback_Fail();

function prophetCallback(bytes calldata /* _callData */ ) external pure returns (bytes memory _callResponse) {
_callResponse = abi.encode(false);
revert MockFailCallback_Fail();
}
}
18 changes: 18 additions & 0 deletions solidity/test/unit/modules/finality/CallbackModule.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,24 @@ contract CallbackModule_Unit_FinalizeRequest is BaseTest {
callbackModule.finalizeRequest(mockRequest, mockResponse, _proposer);
}

function test_finalizationSucceedsWhenCallbackReverts(
address _proposer,
address _target,
bytes calldata _data
) public assumeFuzzable(_target) {
mockRequest.finalityModuleData = abi.encode(ICallbackModule.RequestParameters({target: _target, data: _data}));
mockResponse.requestId = _getId(mockRequest);

// Mock revert and expect the callback
vm.mockCallRevert(
_target, abi.encodeWithSelector(IProphetCallback.prophetCallback.selector, _data), abi.encode('err')
);
vm.expectCall(_target, abi.encodeWithSelector(IProphetCallback.prophetCallback.selector, _data));

vm.prank(address(oracle));
callbackModule.finalizeRequest(mockRequest, mockResponse, _proposer);
}

/**
* @notice Test that the finalizeRequest reverts if caller is not the oracle
*/
Expand Down
72 changes: 63 additions & 9 deletions solidity/test/unit/modules/finality/MultipleCallbacksModule.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import {IProphetCallback} from '../../../../interfaces/IProphetCallback.sol';

contract BaseTest is Test, Helpers {
// The target contract
MultipleCallbacksModule public multipleCallbackModule;
MultipleCallbacksModule public multipleCallbacksModule;
// A mock oracle
IOracle public oracle;

Expand All @@ -31,7 +31,7 @@ contract BaseTest is Test, Helpers {
oracle = IOracle(makeAddr('Oracle'));
vm.etch(address(oracle), hex'069420');

multipleCallbackModule = new MultipleCallbacksModule(oracle);
multipleCallbacksModule = new MultipleCallbacksModule(oracle);
}

function targetHasBytecode(address _target) public view returns (bool _hasBytecode) {
Expand All @@ -51,7 +51,7 @@ contract MultipleCallbacksModule_Unit_ModuleData is BaseTest {
* @notice Test that the moduleName function returns the correct name
*/
function test_moduleNameReturnsName() public view {
assertEq(multipleCallbackModule.moduleName(), 'MultipleCallbacksModule');
assertEq(multipleCallbacksModule.moduleName(), 'MultipleCallbacksModule');
}

/**
Expand All @@ -72,9 +72,9 @@ contract MultipleCallbacksModule_Unit_ModuleData is BaseTest {
}

if (!_valid) {
assertFalse(multipleCallbackModule.validateParameters(abi.encode(_params)));
assertFalse(multipleCallbacksModule.validateParameters(abi.encode(_params)));
} else {
assertTrue(multipleCallbackModule.validateParameters(abi.encode(_params)));
assertTrue(multipleCallbacksModule.validateParameters(abi.encode(_params)));
}
}
}
Expand Down Expand Up @@ -109,16 +109,58 @@ contract MultipleCallbacksModule_Unit_FinalizeRequests is BaseTest {
);

// Check: is the event emitted?
vm.expectEmit(true, true, true, true, address(multipleCallbackModule));
vm.expectEmit(true, true, true, true, address(multipleCallbacksModule));
emit Callback(_requestId, _target, _calldata);
}

// Check: is the event emitted?
vm.expectEmit(true, true, true, true, address(multipleCallbackModule));
vm.expectEmit(true, true, true, true, address(multipleCallbacksModule));
emit RequestFinalized(_requestId, mockResponse, address(oracle));

vm.prank(address(oracle));
multipleCallbackModule.finalizeRequest(mockRequest, mockResponse, address(oracle));
multipleCallbacksModule.finalizeRequest(mockRequest, mockResponse, address(oracle));
}

function test_finalizationSucceedsWhenCallbacksRevert(
address[10] calldata _fuzzedTargets,
bytes[10] calldata _fuzzedData
) public {
address[] memory _targets = new address[](_fuzzedTargets.length);
bytes[] memory _data = new bytes[](_fuzzedTargets.length);

// Copying the values to fresh arrays that we can use to build `RequestParameters`
for (uint256 _i; _i < _fuzzedTargets.length; _i++) {
_targets[_i] = _fuzzedTargets[_i];
_data[_i] = _fuzzedData[_i];
}

mockRequest.finalityModuleData =
abi.encode(IMultipleCallbacksModule.RequestParameters({targets: _targets, data: _data}));
bytes32 _requestId = _getId(mockRequest);
mockResponse.requestId = _requestId;

for (uint256 _i; _i < _targets.length; _i++) {
address _target = _targets[_i];
bytes memory _calldata = _data[_i];

// Skip precompiles, VM, console.log addresses, etc
_assumeFuzzable(_target);
vm.mockCallRevert(
_target, abi.encodeWithSelector(IProphetCallback.prophetCallback.selector, _calldata), abi.encode('err')
);
vm.expectCall(_target, abi.encodeWithSelector(IProphetCallback.prophetCallback.selector, _calldata));

// Check: is the event emitted?
vm.expectEmit(true, true, true, true, address(multipleCallbacksModule));
emit Callback(_requestId, _target, _calldata);
}

// Check: is the event emitted?
vm.expectEmit(true, true, true, true, address(multipleCallbacksModule));
emit RequestFinalized(_requestId, mockResponse, address(oracle));

vm.prank(address(oracle));
multipleCallbacksModule.finalizeRequest(mockRequest, mockResponse, address(oracle));
}

/**
Expand All @@ -130,6 +172,18 @@ contract MultipleCallbacksModule_Unit_FinalizeRequests is BaseTest {
// Check: does it revert if not called by the Oracle?
vm.expectRevert(IModule.Module_OnlyOracle.selector);
vm.prank(_caller);
multipleCallbackModule.finalizeRequest(_request, mockResponse, address(_caller));
multipleCallbacksModule.finalizeRequest(_request, mockResponse, address(_caller));
}

function test_decodeRequestData(address[] memory _targets, bytes[] memory _data) public view {
// Create and set some mock request data
bytes memory _requestData = abi.encode(IMultipleCallbacksModule.RequestParameters({targets: _targets, data: _data}));

// Decode the given request data
IMultipleCallbacksModule.RequestParameters memory _params = multipleCallbacksModule.decodeRequestData(_requestData);

// Check: decoded values match original values?
assertEq(_params.targets, _targets);
assertEq(_params.data, _data);
}
}

0 comments on commit 57f0be5

Please sign in to comment.