diff --git a/contracts/ERC20Splitter.sol b/contracts/ERC20Splitter.sol index 894bb46..7531d69 100644 --- a/contracts/ERC20Splitter.sol +++ b/contracts/ERC20Splitter.sol @@ -7,8 +7,6 @@ import '@openzeppelin/contracts/security/ReentrancyGuard.sol'; contract ERC20Splitter is ReentrancyGuard { // tokenAddress => userAddress => balance mapping(address => mapping(address => uint256)) public balances; - // userAddress => tokenAddress[] - mapping(address => address[]) private userTokens; /** Events **/ @@ -59,36 +57,35 @@ contract ERC20Splitter is ReentrancyGuard { /// @notice Withdraw all tokens that the caller is entitled to. /// Tokens are automatically determined based on previous deposits. - function withdraw() external nonReentrant { - address[] storage senderTokens = userTokens[msg.sender]; + function withdraw(address[] calldata tokenAddresses) external nonReentrant { + uint256 tokenCount = tokenAddresses.length; + require(tokenCount > 0, 'ERC20Splitter: No tokens specified'); - if (senderTokens.length == 0) { - return; - } - - uint256[] memory withdrawnAmounts = new uint256[](senderTokens.length); + uint256[] memory withdrawnAmounts = new uint256[](tokenCount); - for (uint256 i = 0; i < senderTokens.length; i++) { - address tokenAddress = senderTokens[i]; + for (uint256 i = 0; i < tokenCount; i++) { + address tokenAddress = tokenAddresses[i]; uint256 amount = balances[tokenAddress][msg.sender]; - delete balances[tokenAddress][msg.sender]; + if (amount == 0) { + continue; // Skip if no balance + } + + balances[tokenAddress][msg.sender] = 0; if (tokenAddress == address(0)) { payable(msg.sender).transfer(amount); } else { require( IERC20(tokenAddress).transferFrom(address(this), msg.sender, amount), - 'ERC20Splitter: TransferFrom failed' + 'ERC20Splitter: Transfer failed' ); } withdrawnAmounts[i] = amount; } - emit Withdraw(msg.sender, userTokens[msg.sender], withdrawnAmounts); - - delete userTokens[msg.sender]; + emit Withdraw(msg.sender, tokenAddresses, withdrawnAmounts); } /** Internal Functions **/ @@ -125,7 +122,6 @@ contract ERC20Splitter is ReentrancyGuard { for (uint256 i = 0; i < recipients.length; i++) { uint256 recipientAmount = (amount * shares[i]) / MAX_SHARES; balances[tokenAddress][recipients[i]] += recipientAmount; - userTokens[recipients[i]].push(tokenAddress); } } }