Skip to content

Commit

Permalink
feat(RewardsStreamerMP): stream rewards for a period without checking…
Browse files Browse the repository at this point in the history
… a real reward token balance
  • Loading branch information
gravityblast committed Nov 22, 2024
1 parent 4432937 commit ca9b6f1
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 65 deletions.
95 changes: 56 additions & 39 deletions src/RewardsStreamerMP.sol
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ contract RewardsStreamerMP is IStakeManager, TrustedCodehashAccess, ReentrancyGu
error StakingManager__DurationCannotBeZero();

IERC20 public immutable STAKING_TOKEN;
IERC20 public immutable REWARD_TOKEN;

uint256 public constant SCALE_FACTOR = 1e18;
uint256 public constant MP_RATE_PER_YEAR = 1e18;
Expand All @@ -32,8 +31,10 @@ contract RewardsStreamerMP is IStakeManager, TrustedCodehashAccess, ReentrancyGu
uint256 public rewardIndex;
uint256 public lastMPUpdatedTime;

uint256 public rewardsPerSecond;
uint256 public totalRewardsAccrued;
uint256 public rewardAmount;
uint256 public lastRewardTime;
uint256 public rewardStartTime;
uint256 public rewardEndTime;

struct Account {
Expand All @@ -47,9 +48,8 @@ contract RewardsStreamerMP is IStakeManager, TrustedCodehashAccess, ReentrancyGu

mapping(address account => Account data) public accounts;

constructor(address _owner, address _stakingToken, address _rewardToken) TrustedCodehashAccess(_owner) {
constructor(address _owner, address _stakingToken) TrustedCodehashAccess(_owner) {
STAKING_TOKEN = IERC20(_stakingToken);
REWARD_TOKEN = IERC20(_rewardToken);
lastMPUpdatedTime = block.timestamp;
}

Expand All @@ -70,11 +70,6 @@ contract RewardsStreamerMP is IStakeManager, TrustedCodehashAccess, ReentrancyGu
revert StakingManager__CannotRestakeWithLockedFunds();
}

uint256 accountRewards = calculateAccountRewards(msg.sender);
if (accountRewards > 0) {
distributeRewards(msg.sender, accountRewards);
}

account.stakedBalance += amount;
totalStaked += amount;

Expand Down Expand Up @@ -116,11 +111,6 @@ contract RewardsStreamerMP is IStakeManager, TrustedCodehashAccess, ReentrancyGu
_updateGlobalState();
_updateAccountMP(msg.sender);

uint256 accountRewards = calculateAccountRewards(msg.sender);
if (accountRewards > 0) {
distributeRewards(msg.sender, accountRewards);
}

uint256 previousStakedBalance = account.stakedBalance;

uint256 mpToReduce = (account.accountMP * amount * SCALE_FACTOR) / (previousStakedBalance * SCALE_FACTOR);
Expand Down Expand Up @@ -165,6 +155,11 @@ contract RewardsStreamerMP is IStakeManager, TrustedCodehashAccess, ReentrancyGu
// Adjust rewardIndex before updating totalMP
uint256 previousTotalWeight = totalStaked + totalMP;
totalMP += accruedMP;

// FIXME: If newTotalWeight > previousTotalWeight, then (previousTotalWeight) / (newTotalWeight) is less than 1.
// Multiplying rewardIndex by a fraction less than 1 reduces its value.
// Users who have not updated their accountRewardIndex may see a reduction in their pending rewards.
// Possible fix: adjust only when newTotalWeight < previousTotalWeight
uint256 newTotalWeight = totalStaked + totalMP;

if (previousTotalWeight != 0 && newTotalWeight != previousTotalWeight) {
Expand All @@ -183,12 +178,40 @@ contract RewardsStreamerMP is IStakeManager, TrustedCodehashAccess, ReentrancyGu
revert StakingManager__AmountCannotBeZero();
}

// this will call _updateRewardIndex and update the totalRewardsAccrued
_updateGlobalState();

rewardsPerSecond = amount / duration;

rewardEndTime = block.timestamp + duration;
// in case _updateRewardIndex returns earlier,
// we still update the lastRewardTime
lastRewardTime = block.timestamp;
rewardAmount = amount;
rewardStartTime = block.timestamp;
rewardEndTime = block.timestamp + duration;
}

function _calculateAccruedRewards() internal view returns (uint256) {
if (rewardEndTime <= rewardStartTime) {
// No active reward period
return 0;
}

uint256 currentTime = block.timestamp < rewardEndTime ? block.timestamp : rewardEndTime;

if (currentTime <= lastRewardTime) {
// No new rewards have accrued since lastRewardTime
return 0;
}

uint256 timeElapsed = currentTime - lastRewardTime;
uint256 duration = rewardEndTime - rewardStartTime;

if (duration == 0) {
// Prevent division by zero
return 0;
}

uint256 accruedRewards = (timeElapsed * rewardAmount) / duration;
return accruedRewards;
}

function updateRewardIndex() internal {
Expand All @@ -205,13 +228,14 @@ contract RewardsStreamerMP is IStakeManager, TrustedCodehashAccess, ReentrancyGu
return;
}

uint256 newRewards = rewardsPerSecond * elapsedTime;

if (newRewards > 0) {
rewardIndex += (newRewards * SCALE_FACTOR) / totalWeight;
uint256 newRewards = _calculateAccruedRewards();
if (newRewards == 0) {
return;
}

lastRewardTime = applicableTime;
totalRewardsAccrued += newRewards;
rewardIndex += (newRewards * SCALE_FACTOR) / totalWeight;
lastRewardTime = block.timestamp < rewardEndTime ? block.timestamp : rewardEndTime;
}

function _updateAccountMP(address accountAddress) internal {
Expand Down Expand Up @@ -243,33 +267,26 @@ contract RewardsStreamerMP is IStakeManager, TrustedCodehashAccess, ReentrancyGu

function calculateAccountRewards(address accountAddress) public view returns (uint256) {
Account storage account = accounts[accountAddress];

uint256 accountWeight = account.stakedBalance + account.accountMP;
uint256 deltaRewardIndex = rewardIndex - account.accountRewardIndex;
return (accountWeight * deltaRewardIndex) / SCALE_FACTOR;
}

function distributeRewards(address to, uint256 amount) internal {
uint256 rewardBalance = REWARD_TOKEN.balanceOf(address(this));
// If amount is higher than the contract's balance (for rounding error), transfer the balance.
if (amount > rewardBalance) {
amount = rewardBalance;
}

bool success = REWARD_TOKEN.transfer(to, amount);
if (!success) {
revert StakingManager__TransferFailed();
}
return (accountWeight * deltaRewardIndex) / SCALE_FACTOR;
}

function getStakedBalance(address accountAddress) external view returns (uint256) {
return accounts[accountAddress].stakedBalance;
}

function getPendingRewards(address accountAddress) external view returns (uint256) {
return calculateAccountRewards(accountAddress);
}

function getAccount(address accountAddress) external view returns (Account memory) {
return accounts[accountAddress];
}

function totalRewardsSupply() public view returns (uint256) {
return totalRewardsAccrued + _calculateAccruedRewards();
}

function rewardsBalanceOf(address accountAddress) external view returns (uint256) {
return calculateAccountRewards(accountAddress);
}
}
1 change: 0 additions & 1 deletion src/interfaces/IStakeManager.sol
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ interface IStakeManager is ITrustedCodehashAccess {
function getStakedBalance(address _vault) external view returns (uint256 _balance);

function STAKING_TOKEN() external view returns (IERC20);
function REWARD_TOKEN() external view returns (IERC20);
function MIN_LOCKUP_PERIOD() external view returns (uint256);
function MAX_LOCKUP_PERIOD() external view returns (uint256);
function MP_RATE_PER_YEAR() external view returns (uint256);
Expand Down
113 changes: 88 additions & 25 deletions test/RewardsStreamerMP.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
pragma solidity ^0.8.26;

import { Ownable } from "@openzeppelin/contracts/access/Ownable.sol";
import { Test } from "forge-std/Test.sol";
import { Test, console } from "forge-std/Test.sol";
import { RewardsStreamerMP } from "../src/RewardsStreamerMP.sol";
import { StakeVault } from "../src/StakeVault.sol";
import { MockToken } from "./mocks/MockToken.sol";

contract RewardsStreamerMPTest is Test {
MockToken rewardToken;
MockToken stakingToken;
RewardsStreamerMP public streamer;

Expand All @@ -21,10 +20,9 @@ contract RewardsStreamerMPTest is Test {
mapping(address owner => address vault) public vaults;

function setUp() public virtual {
rewardToken = new MockToken("Reward Token", "RT");
stakingToken = new MockToken("Staking Token", "ST");

streamer = new RewardsStreamerMP(admin, address(stakingToken), address(rewardToken));
streamer = new RewardsStreamerMP(admin, address(stakingToken));

address[4] memory accounts = [alice, bob, charlie, dave];
for (uint256 i = 0; i < accounts.length; i++) {
Expand All @@ -38,10 +36,6 @@ contract RewardsStreamerMPTest is Test {
vm.prank(accounts[i]);
stakingToken.approve(address(vault), 10_000e18);
}

rewardToken.mint(admin, 10_000e18);
vm.prank(admin);
rewardToken.approve(address(streamer), 10_000e18);
}

struct CheckStreamerParams {
Expand Down Expand Up @@ -104,12 +98,6 @@ contract RewardsStreamerMPTest is Test {
vault.unstake(amount);
}

function _addReward(uint256 amount) public {
vm.prank(admin);
rewardToken.transfer(address(streamer), amount);
streamer.updateGlobalState();
}

function _calculateBonusMP(uint256 amount, uint256 lockupTime) public view returns (uint256) {
return amount
* (lockupTime * streamer.MAX_MULTIPLIER() * streamer.SCALE_FACTOR() / streamer.MAX_LOCKUP_PERIOD())
Expand Down Expand Up @@ -212,7 +200,7 @@ contract IntegrationTest is RewardsStreamerMPTest {

// T3
vm.prank(admin);
rewardToken.transfer(address(streamer), 1000e18);
// rewardToken.transfer(address(streamer), 1000e18);
streamer.updateGlobalState();

checkStreamer(
Expand Down Expand Up @@ -351,7 +339,7 @@ contract IntegrationTest is RewardsStreamerMPTest {

// T6
vm.prank(admin);
rewardToken.transfer(address(streamer), 1000e18);
// rewardToken.transfer(address(streamer), 1000e18);
streamer.updateGlobalState();

checkStreamer(
Expand Down Expand Up @@ -511,9 +499,6 @@ contract StakeTest is RewardsStreamerMPTest {
})
);

// 1000 rewards generated
_addReward(1000e18);

checkStreamer(
CheckStreamerParams({
totalStaked: 10e18,
Expand Down Expand Up @@ -824,8 +809,6 @@ contract StakeTest is RewardsStreamerMPTest {
maxMP: 150e18
})
);
// 1000 rewards generated
_addReward(1000e18);

checkStreamer(
CheckStreamerParams({
Expand Down Expand Up @@ -1417,17 +1400,21 @@ contract RewardsStreamerMP_RewardsTest is RewardsStreamerMPTest {
}

function testSetRewards() public {
assertEq(streamer.rewardsPerSecond(), 0);
assertEq(streamer.lastRewardTime(), 0);
assertEq(streamer.rewardStartTime(), 0);
assertEq(streamer.rewardEndTime(), 0);
assertEq(streamer.lastRewardTime(), 0);

uint256 currentTime = vm.getBlockTimestamp();
// just to be sure that currentTime is not 0
// since we are testing that it is used for rewardStartTime
currentTime += 1 days;
vm.warp(currentTime);
vm.prank(admin);
streamer.setReward(1000, 10);

assertEq(streamer.rewardsPerSecond(), 100);
assertEq(streamer.lastRewardTime(), currentTime);
assertEq(streamer.rewardStartTime(), currentTime);
assertEq(streamer.rewardEndTime(), currentTime + 10);
assertEq(streamer.lastRewardTime(), currentTime);
}

function testSetRewards_RevertsNotAuthorized() public {
Expand All @@ -1447,4 +1434,80 @@ contract RewardsStreamerMP_RewardsTest is RewardsStreamerMPTest {
vm.expectRevert(RewardsStreamerMP.StakingManager__AmountCannotBeZero.selector);
streamer.setReward(0, 10);
}

function testTotalRewardsSupply() public {
_stake(alice, 100e18, 0);
assertEq(streamer.totalRewardsSupply(), 0);

uint256 initialTime = vm.getBlockTimestamp();

vm.prank(admin);
streamer.setReward(1000e18, 10 days);
assertEq(streamer.totalRewardsSupply(), 0);

for (uint256 i = 0; i <= 10; i++) {
vm.warp(initialTime + i * 1 days);
assertEq(streamer.totalRewardsSupply(), 100e18 * i);
}

// after the end of the reward period, the total rewards supply does not increase
vm.warp(initialTime + 11 days);
assertEq(streamer.totalRewardsSupply(), 1000e18);
assertEq(streamer.totalRewardsAccrued(), 0);

uint256 secondRewardTime = initialTime + 20 days;
vm.warp(secondRewardTime);

// still the same rewards supply after 20 days
assertEq(streamer.totalRewardsSupply(), 1000e18);
assertEq(streamer.totalRewardsAccrued(), 0);

// set other 2000 rewards for other 10 days
vm.prank(admin);
streamer.setReward(2000e18, 10 days);
// accrued is 1000 from the previous reward and still 0 for the new one
assertEq(streamer.totalRewardsSupply(), 1000e18, "totalRewardsSupply should be 1000");
assertEq(streamer.totalRewardsAccrued(), 1000e18);

uint256 previousSupply = 1000e18;
for (uint256 i = 0; i <= 10; i++) {
vm.warp(secondRewardTime + i * 1 days);
assertEq(streamer.totalRewardsSupply(), previousSupply + 200e18 * i);
}
}

function testRewardsBalanceOf() public {
assertEq(streamer.totalRewardsSupply(), 0);

vm.warp(0);

uint256 initialTime = vm.getBlockTimestamp();

_stake(alice, 100e18, 0);
assertEq(streamer.rewardsBalanceOf(vaults[alice]), 0);

vm.prank(admin);
streamer.setReward(1000e18, 10 days);
assertEq(streamer.rewardsBalanceOf(vaults[alice]), 0);

vm.warp(initialTime + 1 days);

// FIXME: this is needed to update the global state and account MP
// Later we should update the functions to use "real-time" values.
streamer.updateGlobalState();
streamer.updateAccountMP(vaults[alice]);

uint256 tolerance = 300; // 300 wei

assertEq(streamer.totalRewardsSupply(), 100e18, "Total rewards supply mismatch");
assertApproxEqAbs(streamer.rewardsBalanceOf(vaults[alice]), 100e18, tolerance);

vm.warp(initialTime + 10 days);

streamer.updateGlobalState();
streamer.updateAccountMP(vaults[alice]);

assertEq(streamer.totalRewardsSupply(), 1000e18, "Total rewards supply mismatch");
assertApproxEqAbs(streamer.rewardsBalanceOf(vaults[alice]), 1000e18, tolerance);
}
}

0 comments on commit ca9b6f1

Please sign in to comment.