diff --git a/contracts/OriginalTokenBridgeUpgradable.sol b/contracts/OriginalTokenBridgeUpgradable.sol index 1b951e4..ed07fa7 100644 --- a/contracts/OriginalTokenBridgeUpgradable.sol +++ b/contracts/OriginalTokenBridgeUpgradable.sol @@ -11,6 +11,12 @@ import {IWETH} from "./interfaces/IWETH.sol"; contract OriginalTokenBridgeUpgradable is TokenBridgeBaseUpgradable { using SafeERC20 for IERC20; + /// @notice Total bps representing 100% + uint16 public constant TOTAL_BPS = 10000; + + /// @notice An optional fee charged on withdrawal, expressed in bps. E.g., 1bps = 0.01% + uint16 public withdrawalFeeBps; + /// @notice Tokens that can be bridged to the remote chain mapping(address => bool) public supportedTokens; @@ -32,6 +38,7 @@ contract OriginalTokenBridgeUpgradable is TokenBridgeBaseUpgradable { event SetRemoteChainId(uint16 remoteChainId); event RegisterToken(address token); event WithdrawFee(address indexed token, address to, uint amount); + event SetWithdrawalFeeBps(uint16 withdrawalFeeBps); function __OriginalTokenBridgeBaseUpgradable_init(address _endpoint, uint16 _remoteChainId, address _weth) internal onlyInitializing { require(_weth != address(0), "OriginalTokenBridge: invalid WETH address"); @@ -40,7 +47,7 @@ contract OriginalTokenBridgeUpgradable is TokenBridgeBaseUpgradable { weth = _weth; } - function initialize(address _endpoint, uint16 _remoteChainId, address _weth) virtual external initializer { + function initialize(address _endpoint, uint16 _remoteChainId, address _weth) external virtual initializer { __OriginalTokenBridgeBaseUpgradable_init(_endpoint, _remoteChainId, _weth); } @@ -75,6 +82,12 @@ contract OriginalTokenBridgeUpgradable is TokenBridgeBaseUpgradable { return lzEndpoint.estimateFees(remoteChainId, address(this), payload, useZro, adapterParams); } + function setWithdrawalFeeBps(uint16 _withdrawalFeeBps) external onlyOwner { + require(_withdrawalFeeBps < TOTAL_BPS, "OriginalTokenBridge: invalid withdrawal fee bps"); + withdrawalFeeBps = _withdrawalFeeBps; + emit SetWithdrawalFeeBps(_withdrawalFeeBps); + } + /// @notice Bridges ERC20 to the remote chain /// @dev Locks an ERC20 on the source chain and sends LZ message to the remote chain to mint a wrapped token function bridge(address token, uint amountLD, address to, LzLib.CallParams calldata callParams, bytes memory adapterParams) external payable nonReentrant { @@ -89,7 +102,6 @@ contract OriginalTokenBridgeUpgradable is TokenBridgeBaseUpgradable { if (dust > 0) { IERC20(token).safeTransfer(msg.sender, dust); } - _bridge(token, amountWithoutDustLD, to, msg.value, callParams, adapterParams); } @@ -107,13 +119,20 @@ contract OriginalTokenBridgeUpgradable is TokenBridgeBaseUpgradable { require(to != address(0), "OriginalTokenBridge: invalid to"); _checkAdapterParams(remoteChainId, PT_MINT, adapterParams); + uint withdrawalAmountLD = amountLD; + if (withdrawalFeeBps > 0) { + uint withdrawalFee = (amountLD * withdrawalFeeBps) / TOTAL_BPS; + withdrawalAmountLD -= withdrawalFee; + } + uint amountSD = _amountLDtoSD(token, amountLD); - require(amountSD > 0, "OriginalTokenBridge: invalid amount"); + uint withdrawalAmountSD = _amountLDtoSD(token, withdrawalAmountLD); + require(amountSD > 0 && withdrawalAmountSD > 0, "OriginalTokenBridge: invalid amount"); - totalValueLockedSD[token] += amountSD; - bytes memory payload = abi.encode(PT_MINT, token, to, amountSD); + totalValueLockedSD[token] += withdrawalAmountSD; + bytes memory payload = abi.encode(PT_MINT, token, to, withdrawalAmountSD); _lzSend(remoteChainId, payload, callParams.refundAddress, callParams.zroPaymentAddress, adapterParams, nativeFee); - emit SendToken(token, msg.sender, to, amountLD); + emit SendToken(token, msg.sender, to, withdrawalAmountSD); } function withdrawFee(address token, address to, uint amountLD) public onlyOwner { diff --git a/test/OriginalTokenBridge.test.js b/test/OriginalTokenBridge.test.js index c3569c0..526fdbe 100644 --- a/test/OriginalTokenBridge.test.js +++ b/test/OriginalTokenBridge.test.js @@ -307,6 +307,60 @@ describe("OriginalTokenBridge", () => { }) }) + describe("sets withdrawal fee", () => { + const withdrawalFeeBps = 100 + it("reverts when called by non owner", async () => { + await expect(originalTokenBridge.connect(user).setWithdrawalFeeBps(withdrawalFeeBps)).to.be.revertedWith("Ownable: caller is not the owner") + }) + + it("reverts when withdrawal fee is greater than 10000", async () => { + const invalidWithdrawalFeeBps = 10001 + await expect(originalTokenBridge.setWithdrawalFeeBps(invalidWithdrawalFeeBps)).to.be.revertedWith("OriginalTokenBridge: invalid withdrawal fee") + }) + + it("sets withdrawal fee", async () => { + await originalTokenBridge.setWithdrawalFeeBps(withdrawalFeeBps) + expect(await originalTokenBridge.withdrawalFeeBps()).to.be.eq(withdrawalFeeBps) + }) + }) + + describe("bridges and withdraw fee", () => { + const withdrawalFeeBps = 100 + let fee + let totalAmount + beforeEach(async () => { + await originalTokenBridge.setWithdrawalFeeBps(withdrawalFeeBps) + fee = (await originalTokenBridge.estimateBridgeFee(false, adapterParams)).nativeFee + totalAmount = amount + fee + await originalToken.connect(user).approve(originalTokenBridge.target, amount) + }) + + it("bridges ERC20 token and withdraws fees", async () => { + await originalTokenBridge.registerToken(originalToken.target, sharedDecimals) + await originalTokenBridge.connect(user).bridge(originalToken.target, amount, user.address, callParams, adapterParams, { value: fee }) + const LDtoSD = await originalTokenBridge.LDtoSDConversionRate(originalToken.target) + + const withdrawalFee = (amount * BigInt(withdrawalFeeBps)) / BigInt(10000) / LDtoSD + const withdrawalAmount = amount / LDtoSD - withdrawalFee + + await originalTokenBridge.connect(owner).withdrawFee(originalToken.target, owner.address, withdrawalFee) + expect(await originalToken.balanceOf(owner.address)).to.be.eq(withdrawalFee) + expect(await originalTokenBridge.totalValueLockedSD(originalToken.target)).to.be.eq(withdrawalAmount) + }) + + it("bridges WETH and withdraws fees", async () => { + await originalTokenBridge.registerToken(weth.target, wethSharedDecimals) + await originalTokenBridge.connect(user).bridgeNative(amount, user.address, callParams, adapterParams, { value: totalAmount }) + + const withdrawalFee = (amount * BigInt(withdrawalFeeBps)) / BigInt(10000) + const withdrawalAmount = amount - withdrawalFee + + await originalTokenBridge.connect(owner).withdrawFee(weth.target, owner.address, withdrawalFee) + expect(await weth.balanceOf(owner.address)).to.be.eq(withdrawalFee) + expect(await originalTokenBridge.totalValueLockedSD(weth.target)).to.be.eq(withdrawalAmount) + }) + }) + describe("Upgrades Contract", () => { beforeEach(async () => { originalTokenBridgeV2Factory = await ethers.getContractFactory("OriginalTokenBridgeHarnessUpgradableV2")