diff --git a/src/RewardsStreamerMP.sol b/src/RewardsStreamerMP.sol index daa4bf4..096036b 100644 --- a/src/RewardsStreamerMP.sol +++ b/src/RewardsStreamerMP.sol @@ -407,9 +407,22 @@ contract RewardsStreamerMP is } function updateRewardIndex() internal { + uint256 accruedRewards; + uint256 newRewardIndex; + + (accruedRewards, newRewardIndex) = _pendingRewardIndex(); + totalRewardsAccrued += accruedRewards; + + if (newRewardIndex > rewardIndex) { + rewardIndex = newRewardIndex; + lastRewardTime = block.timestamp < rewardEndTime ? block.timestamp : rewardEndTime; + } + } + + function _pendingRewardIndex() internal view returns (uint256, uint256) { uint256 totalWeight = totalStaked + totalMPAccrued; if (totalWeight == 0) { - return; + return (0, rewardIndex); } uint256 currentTime = block.timestamp; @@ -417,20 +430,17 @@ contract RewardsStreamerMP is uint256 elapsedTime = applicableTime - lastRewardTime; if (elapsedTime == 0) { - return; + return (0, rewardIndex); } - uint256 newRewards = _calculatePendingRewards(); - if (newRewards == 0) { - return; + uint256 accruedRewards = _calculatePendingRewards(); + if (accruedRewards == 0) { + return (0, rewardIndex); } - totalRewardsAccrued += newRewards; - uint256 indexIncrease = Math.mulDiv(newRewards, SCALE_FACTOR, totalWeight); - if (indexIncrease > 0) { - rewardIndex += indexIncrease; - lastRewardTime = block.timestamp < rewardEndTime ? block.timestamp : rewardEndTime; - } + uint256 newRewardIndex = rewardIndex + Math.mulDiv(accruedRewards, SCALE_FACTOR, totalWeight); + + return (accruedRewards, newRewardIndex); } function _calculateBonusMP(uint256 amount, uint256 lockPeriod) internal pure returns (uint256) { @@ -496,6 +506,23 @@ contract RewardsStreamerMP is } function rewardsBalanceOf(address accountAddress) external view returns (uint256) { - return calculateAccountRewards(accountAddress); + uint256 newRewardIndex; + (, newRewardIndex) = _pendingRewardIndex(); + + Account storage account = accounts[accountAddress]; + + uint256 accountWeight = account.stakedBalance + _mpBalanceOf(accountAddress); + uint256 deltaRewardIndex = newRewardIndex - account.accountRewardIndex; + + return (accountWeight * deltaRewardIndex) / SCALE_FACTOR; + } + + function _mpBalanceOf(address accountAddress) internal view returns (uint256) { + Account storage account = accounts[accountAddress]; + return account.mpAccrued + _getAccountPendingdMP(account); + } + + function mpBalanceOf(address accountAddress) external view returns (uint256) { + return _mpBalanceOf(accountAddress); } } diff --git a/test/RewardsStreamerMP.t.sol b/test/RewardsStreamerMP.t.sol index be21ba6..9abfff6 100644 --- a/test/RewardsStreamerMP.t.sol +++ b/test/RewardsStreamerMP.t.sol @@ -833,7 +833,7 @@ contract StakeTest is RewardsStreamerMPTest { ); uint256 currentTime = vm.getBlockTimestamp(); - uint256 timeToMaxMP = _calculateTimeToAccureMP(stakeAmount, totalMaxMP - totalMP); + uint256 timeToMaxMP = _calculateTimeToAccureMP(stakeAmount, totalMaxMP - totalMPAccrued); vm.warp(currentTime + timeToMaxMP); streamer.updateGlobalState();