diff --git a/src/FlywheelCore.sol b/src/FlywheelCore.sol index 593219b..eeb8bca 100644 --- a/src/FlywheelCore.sol +++ b/src/FlywheelCore.sol @@ -40,6 +40,8 @@ contract FlywheelCore is Auth { /// @notice optional booster module for calculating virtual balances on strategies IFlywheelBooster public flywheelBooster; + error InvalidAddress(); + constructor( ERC20 _rewardToken, IFlywheelRewards _flywheelRewards, @@ -163,9 +165,12 @@ contract FlywheelCore is Auth { /// @notice swap out the flywheel rewards contract function setFlywheelRewards(IFlywheelRewards newFlywheelRewards) external requiresAuth { - uint256 oldRewardBalance = rewardToken.balanceOf(address(flywheelRewards)); - if (oldRewardBalance > 0) { - rewardToken.safeTransferFrom(address(flywheelRewards), address(newFlywheelRewards), oldRewardBalance); + if (address(newFlywheelRewards) == address(0)) revert InvalidAddress(); + if (address(flywheelRewards) != address(0)) { + uint256 oldRewardBalance = rewardToken.balanceOf(address(flywheelRewards)); + + if (oldRewardBalance != 0) + rewardToken.safeTransferFrom(address(flywheelRewards), address(newFlywheelRewards), oldRewardBalance); } flywheelRewards = newFlywheelRewards;