From 7703e8534ded79dbf434fd80b4003d2980f86539 Mon Sep 17 00:00:00 2001 From: Matt Stam Date: Mon, 30 Oct 2023 17:09:18 -0700 Subject: [PATCH] fix: FunctionRegistry bugs + tests (#266) --- contracts/.gitignore | 4 +- contracts/src/FunctionRegistry.sol | 11 +- .../src/interfaces/IFunctionRegistry.sol | 1 + contracts/test/SuccinctGateway.t.sol | 295 +++++++++++++++++- contracts/test/TestUtils.sol | 14 +- 5 files changed, 310 insertions(+), 15 deletions(-) diff --git a/contracts/.gitignore b/contracts/.gitignore index b641e1ff0..84615e6d7 100644 --- a/contracts/.gitignore +++ b/contracts/.gitignore @@ -11,9 +11,7 @@ cache/ out/ # Ignores development broadcast logs -!/broadcast -/broadcast/*/31337/ -/broadcast/**/dry-run/ +/broadcast # Docs docs/ diff --git a/contracts/src/FunctionRegistry.sol b/contracts/src/FunctionRegistry.sol index e18274e6b..cad349597 100644 --- a/contracts/src/FunctionRegistry.sol +++ b/contracts/src/FunctionRegistry.sol @@ -3,7 +3,7 @@ pragma solidity ^0.8.16; import {IFunctionRegistry} from "./interfaces/IFunctionRegistry.sol"; -contract FunctionRegistry is IFunctionRegistry { +abstract contract FunctionRegistry is IFunctionRegistry { /// @dev Maps function identifiers to their corresponding verifiers. mapping(bytes32 => address) public verifiers; @@ -18,7 +18,7 @@ contract FunctionRegistry is IFunctionRegistry { external returns (bytes32 functionId) { - functionId = getFunctionId(msg.sender, _name); + functionId = getFunctionId(_owner, _name); if (address(verifiers[functionId]) != address(0)) { revert FunctionAlreadyRegistered(functionId); // should call update instead } @@ -39,7 +39,7 @@ contract FunctionRegistry is IFunctionRegistry { external returns (bytes32 functionId, address verifier) { - functionId = getFunctionId(msg.sender, _name); + functionId = getFunctionId(_owner, _name); if (address(verifiers[functionId]) != address(0)) { revert FunctionAlreadyRegistered(functionId); // should call update instead } @@ -66,6 +66,9 @@ contract FunctionRegistry is IFunctionRegistry { if (_verifier == address(0)) { revert VerifierCannotBeZero(); } + if (_verifier == verifiers[functionId]) { + revert VerifierAlreadyUpdated(functionId); + } verifiers[functionId] = _verifier; emit FunctionVerifierUpdated(functionId, _verifier); @@ -90,7 +93,7 @@ contract FunctionRegistry is IFunctionRegistry { } /// @notice Returns the functionId for a given owner and function name. - /// @param _owner The owner of the function (sender of registerFunction). + /// @param _owner The owner of the function. /// @param _name The name of the function. function getFunctionId(address _owner, string memory _name) public diff --git a/contracts/src/interfaces/IFunctionRegistry.sol b/contracts/src/interfaces/IFunctionRegistry.sol index 18d8fc7c1..6b00e0014 100644 --- a/contracts/src/interfaces/IFunctionRegistry.sol +++ b/contracts/src/interfaces/IFunctionRegistry.sol @@ -15,6 +15,7 @@ interface IFunctionRegistryErrors { error EmptyBytecode(); error FailedDeploy(); error VerifierCannotBeZero(); + error VerifierAlreadyUpdated(bytes32 functionId); error FunctionAlreadyRegistered(bytes32 functionId); error NotFunctionOwner(address owner, address actualOwner); } diff --git a/contracts/test/SuccinctGateway.t.sol b/contracts/test/SuccinctGateway.t.sol index ad92011c6..d58a565b4 100644 --- a/contracts/test/SuccinctGateway.t.sol +++ b/contracts/test/SuccinctGateway.t.sol @@ -11,10 +11,22 @@ import { ISuccinctGatewayEvents, ISuccinctGatewayErrors } from "src/interfaces/ISuccinctGateway.sol"; -import {IFunctionRegistry} from "src/interfaces/IFunctionRegistry.sol"; -import {TestConsumer, AttackConsumer, TestFunctionVerifier} from "test/TestUtils.sol"; +import { + TestConsumer, + AttackConsumer, + TestFunctionVerifier1, + TestFunctionVerifier2 +} from "test/TestUtils.sol"; +import { + IFunctionRegistry, + IFunctionRegistryEvents, + IFunctionRegistryErrors +} from "src/interfaces/IFunctionRegistry.sol"; +import {TestConsumer, TestFunctionVerifier1} from "test/TestUtils.sol"; import {Proxy} from "src/upgrades/Proxy.sol"; import {SuccinctFeeVault} from "src/payments/SuccinctFeeVault.sol"; +import {AccessControlUpgradeable} from + "@openzeppelin-upgradeable/contracts/access/AccessControlUpgradeable.sol"; contract SuccinctGatewayTest is Test, ISuccinctGatewayEvents, ISuccinctGatewayErrors { // Example Function Request and expected values. @@ -56,7 +68,7 @@ contract SuccinctGatewayTest is Test, ISuccinctGatewayEvents, ISuccinctGatewayEr bytes32 functionId; vm.prank(sender); (functionId, verifier) = IFunctionRegistry(gateway).deployAndRegisterFunction( - owner, type(TestFunctionVerifier).creationCode, "test-verifier" + owner, type(TestFunctionVerifier1).creationCode, "test-verifier" ); // Deploy TestConsumer @@ -66,6 +78,11 @@ contract SuccinctGatewayTest is Test, ISuccinctGatewayEvents, ISuccinctGatewayEr vm.deal(consumer, DEFAULT_FEE); } + function test_SetUp() public { + assertTrue(AccessControlUpgradeable(gateway).hasRole(keccak256("TIMELOCK_ROLE"), timelock)); + assertTrue(AccessControlUpgradeable(gateway).hasRole(keccak256("GUARDIAN_ROLE"), guardian)); + } + function test_Callback() public { uint32 prevNonce = SuccinctGateway(gateway).nonce(); assertEq(prevNonce, 0); @@ -356,7 +373,7 @@ contract SuccinctGatewayTest is Test, ISuccinctGatewayEvents, ISuccinctGatewayEr vm.store(gateway, bytes32(uint256(255)), functionId); vm.store(gateway, bytes32(uint256(256)), inputHash); - // Verifiy call + // Verify call TestConsumer(consumer).verifiedCall(input); } @@ -364,7 +381,7 @@ contract SuccinctGatewayTest is Test, ISuccinctGatewayEvents, ISuccinctGatewayEr bytes memory input = INPUT; bytes32 functionId = TestConsumer(consumer).FUNCTION_ID(); - // Verifiy call + // Verify call vm.expectRevert(abi.encodeWithSelector(InvalidCall.selector, functionId, input)); TestConsumer(consumer).verifiedCall(input); } @@ -411,7 +428,7 @@ contract AttackSuccinctGateway is SuccinctGatewayTest { bytes32 functionId; vm.prank(sender); (functionId, verifier) = IFunctionRegistry(gateway).deployAndRegisterFunction( - owner, type(TestFunctionVerifier).creationCode, "attack-verifier" + owner, type(TestFunctionVerifier1).creationCode, "attack-verifier" ); // Deploy AttackConsumer @@ -518,3 +535,269 @@ contract AttackSuccinctGateway is SuccinctGatewayTest { ); } } + +contract FunctionRegistryTest is + SuccinctGatewayTest, + IFunctionRegistryEvents, + IFunctionRegistryErrors +{ + function test_RegisterFunction() public { + bytes32 expectedFunctionId1 = + IFunctionRegistry(gateway).getFunctionId(owner, "test-verifier1"); + + // Deploy verifier + address verifier1; + bytes memory bytecode = type(TestFunctionVerifier1).creationCode; + bytes32 salt = expectedFunctionId1; + assembly { + verifier1 := create2(0, add(bytecode, 32), mload(bytecode), salt) + } + + // Register function + vm.expectEmit(true, true, true, true, gateway); + emit FunctionRegistered(expectedFunctionId1, verifier1, "test-verifier1", owner); + bytes32 functionId1 = + IFunctionRegistry(gateway).registerFunction(owner, verifier1, "test-verifier1"); + + assertEq(functionId1, expectedFunctionId1); + assertEq(IFunctionRegistry(gateway).verifiers(expectedFunctionId1), verifier1); + assertEq(IFunctionRegistry(gateway).verifierOwners(expectedFunctionId1), owner); + } + + function test_RegisterFunction_WhenOwnerIsSender() public { + bytes32 expectedFunctionId1 = + IFunctionRegistry(gateway).getFunctionId(owner, "test-verifier1"); + + // Deploy verifier + address verifier1; + bytes memory bytecode = type(TestFunctionVerifier1).creationCode; + bytes32 salt = expectedFunctionId1; + assembly { + verifier1 := create2(0, add(bytecode, 32), mload(bytecode), salt) + } + + // Register function + vm.expectEmit(true, true, true, true, gateway); + emit FunctionRegistered(expectedFunctionId1, verifier1, "test-verifier1", owner); + vm.prank(owner); + bytes32 functionId1 = + IFunctionRegistry(gateway).registerFunction(owner, verifier1, "test-verifier1"); + + assertEq(functionId1, expectedFunctionId1); + assertEq(IFunctionRegistry(gateway).verifiers(expectedFunctionId1), verifier1); + assertEq(IFunctionRegistry(gateway).verifierOwners(expectedFunctionId1), owner); + } + + function test_RevertRegisterFunction_WhenAlreadyRegistered() public { + // Deploy verifier + address verifier1; + bytes memory bytecode = type(TestFunctionVerifier1).creationCode; + bytes32 salt = IFunctionRegistry(gateway).getFunctionId(owner, "test-verifier1"); + assembly { + verifier1 := create2(0, add(bytecode, 32), mload(bytecode), salt) + } + + // Register function + vm.expectEmit(true, true, true, true, gateway); + emit FunctionRegistered(salt, verifier1, "test-verifier1", owner); + IFunctionRegistry(gateway).registerFunction(owner, verifier1, "test-verifier1"); + + // Register function again + vm.expectRevert(abi.encodeWithSelector(FunctionAlreadyRegistered.selector, salt)); + IFunctionRegistry(gateway).registerFunction(owner, verifier1, "test-verifier1"); + } + + function test_DeployAndRegisterFunction() public { + bytes32 expectedFunctionId1 = + IFunctionRegistry(gateway).getFunctionId(owner, "test-verifier1"); + + // Deploy verifier and register function + vm.expectEmit(true, false, false, true, gateway); + emit Deployed( + keccak256(type(TestFunctionVerifier1).creationCode), expectedFunctionId1, address(0) + ); + vm.expectEmit(true, true, true, false, gateway); + emit FunctionRegistered(expectedFunctionId1, address(0), "test-verifier1", owner); + (bytes32 functionId1, address verifier1) = IFunctionRegistry(gateway) + .deployAndRegisterFunction( + owner, type(TestFunctionVerifier1).creationCode, "test-verifier1" + ); + + assertEq(functionId1, expectedFunctionId1); + assertEq(IFunctionRegistry(gateway).verifiers(functionId1), verifier1); + assertEq(IFunctionRegistry(gateway).verifierOwners(functionId1), owner); + } + + function test_DeployAndRegisterFunction_WhenOwnerIsSender() public { + bytes32 expectedFunctionId1 = + IFunctionRegistry(gateway).getFunctionId(owner, "test-verifier1"); + + // Deploy verifier and register function + vm.expectEmit(true, false, false, true, gateway); + emit Deployed( + keccak256(type(TestFunctionVerifier1).creationCode), expectedFunctionId1, address(0) + ); + vm.expectEmit(true, true, true, false, gateway); + emit FunctionRegistered(expectedFunctionId1, address(0), "test-verifier1", owner); + vm.prank(owner); + (bytes32 functionId1, address verifier1) = IFunctionRegistry(gateway) + .deployAndRegisterFunction( + owner, type(TestFunctionVerifier1).creationCode, "test-verifier1" + ); + + assertEq(functionId1, expectedFunctionId1); + assertEq(IFunctionRegistry(gateway).verifiers(functionId1), verifier1); + assertEq(IFunctionRegistry(gateway).verifierOwners(functionId1), owner); + } + + function test_RevertDeployAndRegisterFunction_WhenAlreadyRegistered() public { + // Deploy verifier and register function + (bytes32 functionId1,) = IFunctionRegistry(gateway).deployAndRegisterFunction( + owner, type(TestFunctionVerifier1).creationCode, "test-verifier1" + ); + + // Deploy verifier and register function again + vm.expectRevert(abi.encodeWithSelector(FunctionAlreadyRegistered.selector, functionId1)); + IFunctionRegistry(gateway).deployAndRegisterFunction( + owner, type(TestFunctionVerifier1).creationCode, "test-verifier1" + ); + } + + function test_UpdateFunction() public { + bytes32 expectedFunctionId1 = + IFunctionRegistry(gateway).getFunctionId(owner, "test-verifier1"); + + // Deploy verifier and register function + IFunctionRegistry(gateway).deployAndRegisterFunction( + owner, type(TestFunctionVerifier1).creationCode, "test-verifier1" + ); + + // Deploy verifier + address verifier2; + bytes memory bytecode = type(TestFunctionVerifier2).creationCode; + bytes32 salt = expectedFunctionId1; + assembly { + verifier2 := create2(0, add(bytecode, 32), mload(bytecode), salt) + } + + // Update function + vm.expectEmit(true, true, true, true, gateway); + emit FunctionVerifierUpdated(expectedFunctionId1, verifier2); + vm.prank(owner); + bytes32 functionId1 = IFunctionRegistry(gateway).updateFunction(verifier2, "test-verifier1"); + + assertEq(functionId1, expectedFunctionId1); + assertEq(IFunctionRegistry(gateway).verifiers(functionId1), verifier2); + assertEq(IFunctionRegistry(gateway).verifierOwners(functionId1), owner); + } + + function test_RevertUpdateFunction_WhenNotOwner() public { + // Deploy verifier and register function + (bytes32 functionId,) = IFunctionRegistry(gateway).deployAndRegisterFunction( + owner, type(TestFunctionVerifier1).creationCode, "test-verifier1" + ); + + // Deploy verifier + address verifier2; + bytes memory bytecode = type(TestFunctionVerifier2).creationCode; + bytes32 salt = functionId; + assembly { + verifier2 := create2(0, add(bytecode, 32), mload(bytecode), salt) + } + + // Update function + vm.prank(sender); + vm.expectRevert(abi.encodeWithSelector(NotFunctionOwner.selector, sender, address(0))); + IFunctionRegistry(gateway).updateFunction(verifier2, "test-verifier1"); + } + + function test_RevertUpdateFunction_WhenNeverRegistered() public { + // Deploy verifier + address verifier2; + bytes memory bytecode = type(TestFunctionVerifier2).creationCode; + bytes32 salt = bytes32(0); + assembly { + verifier2 := create2(0, add(bytecode, 32), mload(bytecode), salt) + } + + // Update function + vm.expectRevert(abi.encodeWithSelector(NotFunctionOwner.selector, owner, address(0))); + vm.prank(owner); + IFunctionRegistry(gateway).updateFunction(verifier2, "test-verifier1"); + } + + function test_RevertUpdateFunction_WhenVerifierSame() public { + // Deploy verifier and register function + (bytes32 functionId1, address verifier1) = IFunctionRegistry(gateway) + .deployAndRegisterFunction( + owner, type(TestFunctionVerifier1).creationCode, "test-verifier1" + ); + + // Update function + vm.expectRevert(abi.encodeWithSelector(VerifierAlreadyUpdated.selector, functionId1)); + vm.prank(owner); + IFunctionRegistry(gateway).updateFunction(verifier1, "test-verifier1"); + } + + function test_deployAndUpdateFunction() public { + bytes32 expectedFunctionId1 = + IFunctionRegistry(gateway).getFunctionId(owner, "test-verifier1"); + + // Deploy verifier and register function + IFunctionRegistry(gateway).deployAndRegisterFunction( + owner, type(TestFunctionVerifier1).creationCode, "test-verifier1" + ); + + // Deploy verifier and update function + vm.expectEmit(true, false, false, true, gateway); + emit Deployed( + keccak256(type(TestFunctionVerifier2).creationCode), expectedFunctionId1, address(0) + ); + vm.expectEmit(true, true, true, false, gateway); + emit FunctionVerifierUpdated(expectedFunctionId1, address(0)); + vm.prank(owner); + (bytes32 functionId1, address verifier2) = IFunctionRegistry(gateway) + .deployAndUpdateFunction(type(TestFunctionVerifier2).creationCode, "test-verifier1"); + + assertEq(functionId1, expectedFunctionId1); + assertEq(IFunctionRegistry(gateway).verifiers(functionId1), verifier2); + assertEq(IFunctionRegistry(gateway).verifierOwners(functionId1), owner); + } + + function test_RevertDeployAndUpdateFunction_WhenNotOwner() public { + // Deploy verifier and register function + IFunctionRegistry(gateway).deployAndRegisterFunction( + owner, type(TestFunctionVerifier1).creationCode, "test-verifier1" + ); + + // Deploy verifier and update function + vm.prank(sender); + vm.expectRevert(abi.encodeWithSelector(NotFunctionOwner.selector, sender, address(0))); + IFunctionRegistry(gateway).deployAndUpdateFunction( + type(TestFunctionVerifier2).creationCode, "test-verifier1" + ); + } + + function test_RevertDeployAndUpdateFunction_WhenNeverRegistered() public { + // Deploy verifier and update function + vm.expectRevert(abi.encodeWithSelector(NotFunctionOwner.selector, owner, address(0))); + vm.prank(owner); + IFunctionRegistry(gateway).deployAndUpdateFunction( + type(TestFunctionVerifier1).creationCode, "test-verifier1" + ); + } + + function test_RevertDeployAndUpdateFunction_WhenBytecodeSame() public { + // Deploy verifier and register function + IFunctionRegistry(gateway).deployAndRegisterFunction( + owner, type(TestFunctionVerifier1).creationCode, "test-verifier1" + ); + + // Deploy verifier and update function + vm.expectRevert(abi.encodeWithSelector(FailedDeploy.selector)); + vm.prank(owner); + IFunctionRegistry(gateway).deployAndUpdateFunction( + type(TestFunctionVerifier1).creationCode, "test-verifier1" + ); + } +} diff --git a/contracts/test/TestUtils.sol b/contracts/test/TestUtils.sol index 0cf7f9ced..276b55b3d 100644 --- a/contracts/test/TestUtils.sol +++ b/contracts/test/TestUtils.sol @@ -195,9 +195,19 @@ contract AttackConsumer is Test { } } -contract TestFunctionVerifier is IFunctionVerifier { +contract TestFunctionVerifier1 is IFunctionVerifier { function verificationKeyHash() external pure returns (bytes32) { - return keccak256("verificationKeyHash"); + return keccak256("verificationKeyHash1"); + } + + function verify(bytes32, bytes32, bytes memory) external pure returns (bool) { + return true; + } +} + +contract TestFunctionVerifier2 is IFunctionVerifier { + function verificationKeyHash() external pure returns (bytes32) { + return keccak256("verificationKeyHash2"); } function verify(bytes32, bytes32, bytes memory) external pure returns (bool) {