Skip to content

Commit

Permalink
Merge pull request #165 from Popcorn-Limited/feat/multi-strategy-update
Browse files Browse the repository at this point in the history
Feat/multi strategy update
  • Loading branch information
RedVeil authored Mar 15, 2024
2 parents f315b1f + 4191c81 commit 1b19ecd
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 29 deletions.
91 changes: 62 additions & 29 deletions src/vaults/MultiStrategyVault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,11 @@ contract MultiStrategyVault is
if (totalSupply() == 0) feesUpdatedAt = block.timestamp;

uint256 feeShares = _convertToShares(
assets.mulDiv(uint256(fees.deposit), 1e18, Math.Rounding.Floor),
Math.Rounding.Floor
assets.mulDiv(uint256(fees.deposit), 1e18, Math.Rounding.Floor),
Math.Rounding.Floor
);

shares = _convertToShares(assets, Math.Rounding.Floor) - feeShares;
shares = _convertToShares(assets, Math.Rounding.Floor) - feeShares;
if (shares == 0) revert ZeroAmount();

if (feeShares > 0) _mint(feeRecipient, feeShares);
Expand All @@ -187,7 +187,10 @@ contract MultiStrategyVault is

IERC20(asset()).safeTransferFrom(msg.sender, address(this), assets);

strategies[0].deposit(assets, address(this));
// deposit into default index strategy or leave funds idle
if (defaultDepositIndex != type(uint256).max) {
strategies[defaultDepositIndex].deposit(assets, address(this));
}

emit Deposit(msg.sender, receiver, assets, shares);
}
Expand Down Expand Up @@ -217,10 +220,10 @@ contract MultiStrategyVault is
uint256 feeShares = shares.mulDiv(
depositFee,
1e18 - depositFee,
Math.Rounding.Floor
Math.Rounding.Floor
);

assets = _convertToAssets(shares + feeShares, Math.Rounding.Ceil);
assets = _convertToAssets(shares + feeShares, Math.Rounding.Ceil);

if (assets > maxMint(receiver)) revert MaxError(assets);

Expand All @@ -230,7 +233,10 @@ contract MultiStrategyVault is

IERC20(asset()).safeTransferFrom(msg.sender, address(this), assets);

strategies[0].deposit(assets, address(this));
// deposit into default index strategy or leave funds idle
if (defaultDepositIndex != type(uint256).max) {
strategies[defaultDepositIndex].deposit(assets, address(this));
}

emit Deposit(msg.sender, receiver, assets, shares);
}
Expand All @@ -254,15 +260,15 @@ contract MultiStrategyVault is
if (receiver == address(0)) revert InvalidReceiver();
if (assets > maxWithdraw(owner)) revert MaxError(assets);

shares = _convertToShares(assets, Math.Rounding.Ceil);
shares = _convertToShares(assets, Math.Rounding.Ceil);
if (shares == 0) revert ZeroAmount();

uint256 withdrawalFee = uint256(fees.withdrawal);

uint256 feeShares = shares.mulDiv(
withdrawalFee,
1e18 - withdrawalFee,
Math.Rounding.Floor
Math.Rounding.Floor
);

shares += feeShares;
Expand Down Expand Up @@ -305,10 +311,10 @@ contract MultiStrategyVault is
uint256 feeShares = shares.mulDiv(
uint256(fees.withdrawal),
1e18,
Math.Rounding.Floor
Math.Rounding.Floor
);

assets = _convertToAssets(shares - feeShares, Math.Rounding.Floor);
assets = _convertToAssets(shares - feeShares, Math.Rounding.Floor);

_burn(owner, shares);

Expand All @@ -322,26 +328,26 @@ contract MultiStrategyVault is
function _withdrawStrategyFunds(uint256 amount, address receiver) internal {
// caching
IERC20 asset_ = IERC20(asset());
uint256[] memory withdrawalQueue_ = withdrawalQueue;

// Get the Vault's floating balance.
uint256 float = asset_.balanceOf(address(this));

if (amount < float){
if (amount < float) {
asset_.safeTransfer(receiver, amount);
} else {
// If the amount is greater than the float, withdraw from strategies.
if (float > 0) {
asset_.safeTransfer(receiver, float);
}
// We'll start at the tip of the stack and traverse backwards.
uint256 currentIndex = strategies.length - 1;

// Iterate in reverse so we pull from the stack in a "last in, first out" manner.
// Iterate the withdrawal queue and get indexes
// Will revert due to underflow if we empty the stack before pulling the desired amount.
for (; ; currentIndex--) {
uint256 len = withdrawalQueue_.length;
for (uint256 i = 0; i < len; i++) {
uint256 missing = amount - float;

IERC4626 strategy = strategies[currentIndex];
IERC4626 strategy = strategies[withdrawalQueue_[i]];

uint256 withdrawableAssets = strategy.previewRedeem(
strategy.balanceOf(address(this))
Expand All @@ -359,7 +365,7 @@ contract MultiStrategyVault is
float += withdrawableAssets;
}
}
}
}
}

/*//////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -391,9 +397,9 @@ contract MultiStrategyVault is
assets -= assets.mulDiv(
uint256(fees.deposit),
1e18,
Math.Rounding.Floor
Math.Rounding.Floor
);
shares = _convertToShares(assets, Math.Rounding.Floor);
shares = _convertToShares(assets, Math.Rounding.Floor);
}

/**
Expand All @@ -409,9 +415,9 @@ contract MultiStrategyVault is
shares += shares.mulDiv(
depositFee,
1e18 - depositFee,
Math.Rounding.Floor
Math.Rounding.Floor
);
assets = _convertToAssets(shares, Math.Rounding.Ceil);
assets = _convertToAssets(shares, Math.Rounding.Ceil);
}

/**
Expand All @@ -423,13 +429,13 @@ contract MultiStrategyVault is
function previewWithdraw(
uint256 assets
) public view override returns (uint256 shares) {
shares = _convertToShares(assets, Math.Rounding.Ceil);
shares = _convertToShares(assets, Math.Rounding.Ceil);

uint256 withdrawalFee = uint256(fees.withdrawal);
shares += shares.mulDiv(
withdrawalFee,
1e18 - withdrawalFee,
Math.Rounding.Floor
Math.Rounding.Floor
);
}

Expand All @@ -445,10 +451,10 @@ contract MultiStrategyVault is
uint256 feeShares = shares.mulDiv(
uint256(fees.withdrawal),
1e18,
Math.Rounding.Floor
Math.Rounding.Floor
);

assets = _convertToAssets(shares - feeShares, Math.Rounding.Floor);
assets = _convertToAssets(shares - feeShares, Math.Rounding.Floor);
}

// TODO - is this now inherited anyways?
Expand Down Expand Up @@ -515,7 +521,7 @@ contract MultiStrategyVault is
? managementFee.mulDiv(
totalAssets() * (block.timestamp - feesUpdatedAt),
SECONDS_PER_YEAR,
Math.Rounding.Floor
Math.Rounding.Floor
) / 1e18
: 0;
}
Expand All @@ -536,7 +542,7 @@ contract MultiStrategyVault is
? performanceFee.mulDiv(
(shareValue - highWaterMark_) * totalSupply(),
1e36,
Math.Rounding.Floor
Math.Rounding.Floor
)
: 0;
}
Expand Down Expand Up @@ -573,7 +579,7 @@ contract MultiStrategyVault is
: totalFee.mulDiv(
supply,
currentAssets - totalFee,
Math.Rounding.Floor
Math.Rounding.Floor
);
_mint(feeRecipient, feeInShare);
}
Expand Down Expand Up @@ -657,11 +663,15 @@ contract MultiStrategyVault is
IERC4626[] public strategies;
IERC4626[] public proposedStrategies;
uint256 public proposedStrategyTime;
uint256 public defaultDepositIndex; // index of the strategy to deposit funds by default - if uint.max, leave funds idle
uint256[] public withdrawalQueue; // indexes of the strategy order in the withdrawal queue

event NewStrategiesProposed();
event ChangedStrategies();

error VaultAssetMismatchNewAdapterAsset();
error InvalidIndex();
error InvalidWithdrawalQueue();

function getStrategies() external view returns (IERC4626[] memory) {
return strategies;
Expand All @@ -671,6 +681,29 @@ contract MultiStrategyVault is
return proposedStrategies;
}

function setDefaultDepositIndex(uint256 index) external onlyOwner {
if (index > strategies.length - 1 && index != type(uint256).max)
revert InvalidIndex();

defaultDepositIndex = index;
}

function setWithdrawalQueue(uint256[] memory indexes) external onlyOwner {
if (indexes.length != strategies.length)
revert InvalidWithdrawalQueue();

withdrawalQueue = new uint256[](indexes.length);

for (uint256 i = 0; i < indexes.length; i++) {
uint256 index = indexes[i];

if (index > strategies.length - 1 && index != type(uint256).max)
revert InvalidIndex();

withdrawalQueue[i] = index;
}
}

/**
* @notice Propose a new adapter for this vault. Caller must be Owner.
* @param strategies_ A new ERC4626 that should be used as a yield adapter for this asset.
Expand Down
92 changes: 92 additions & 0 deletions test/vault/MultiStrategyVault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ contract MultiStrategyVaultTest is Test {
type(uint256).max,
address(this)
);

uint256[] memory withdrawalQueue = new uint256[](2);
withdrawalQueue[0] = 0;
withdrawalQueue[1] = 1;

vault.setWithdrawalQueue(withdrawalQueue);
}

/*//////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -1014,10 +1020,96 @@ contract MultiStrategyVaultTest is Test {
vault.changeStrategies();
}

function testFail_changeWithdrawalQueue_invalidLength() public {
uint256[] memory withdrawalQueue = new uint256[](1);
withdrawalQueue[0] = 0;

vault.setWithdrawalQueue(withdrawalQueue);
}

function testFail_changeWithdrawalQueue_invalidIndex() public {
uint256[] memory withdrawalQueue = new uint256[](2);
withdrawalQueue[0] = 5;
withdrawalQueue[1] = 0;

vault.setWithdrawalQueue(withdrawalQueue);
}

/*//////////////////////////////////////////////////////////////
PULL AND PUSH FUNDS
//////////////////////////////////////////////////////////////*/

function test_deposit_fundsIdle() public {
// set default index to be type max
vault.setDefaultDepositIndex(type(uint256).max);

uint256 amount = 1e18;
_depositIntoVault(bob, amount);

assertEq(asset.balanceOf(address(strategies[0])), 0);
assertEq(asset.balanceOf(address(strategies[1])), 0);
assertEq(asset.balanceOf(address(vault)), amount);
}

function test_withdrawIdleFunds() public {
// set default index to be type max
vault.setDefaultDepositIndex(type(uint256).max);

uint256 amount = 1e18;
_depositIntoVault(bob, amount);

assertEq(asset.balanceOf(address(strategies[0])), 0);
assertEq(asset.balanceOf(address(strategies[1])), 0);
assertEq(asset.balanceOf(address(vault)), amount);

uint256 balBobBefore = asset.balanceOf(bob);

vm.prank(bob);
vault.withdraw(amount);

assertEq(asset.balanceOf(address(strategies[0])), 0);
assertEq(asset.balanceOf(address(strategies[1])), 0);
assertEq(asset.balanceOf(address(vault)), 0);

assertEq(asset.balanceOf(bob), balBobBefore + amount);
}

function test_withdraw_queueOrder() public {
_depositIntoVault(bob, 10e18);

assertEq(asset.balanceOf(address(strategies[0])), 10e18);
assertEq(asset.balanceOf(address(strategies[1])), 0);

Allocation[] memory allocations = new Allocation[](2);
allocations[0] = Allocation({index: 0, amount: 10e18});

vault.pullFunds(allocations);

allocations[0] = Allocation({index: 0, amount: 1e18});
allocations[1] = Allocation({index: 1, amount: 9e18});

vault.pushFunds(allocations);

assertEq(asset.balanceOf(address(strategies[1])), 9e18);
assertEq(asset.balanceOf(address(strategies[0])), 1e18);

uint256[] memory withdrawalQueue = new uint256[](2);
withdrawalQueue[0] = 1;
withdrawalQueue[1] = 0;

vault.setWithdrawalQueue(withdrawalQueue);

vm.prank(bob);
vault.withdraw(95e17);

assertEq(asset.balanceOf(address(strategies[1])), 0);
assertEq(asset.balanceOf(address(strategies[0])), 5e17);
}

function testFail_setDefaultIndex_invalidIndex() public {
vault.setDefaultDepositIndex(5);
}

function _depositIntoVault(address user, uint256 amount) internal {
asset.mint(user, amount);

Expand Down

0 comments on commit 1b19ecd

Please sign in to comment.