diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index f598409f..f6ec4534 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -1648,7 +1648,7 @@ async def cancel_request_remove_insurance_fund_stake(self, spot_market_index: in ) async def get_cancel_request_remove_insurance_fund_stake_ix( - self, spot_market_index: int + self, spot_market_index: int, user_token_account: Pubkey = None ): ra = self.get_remaining_accounts( writable_spot_market_indexes=[spot_market_index] @@ -1672,26 +1672,33 @@ async def get_cancel_request_remove_insurance_fund_stake_ix( "insurance_fund_vault": get_insurance_fund_vault_public_key( self.program_id, spot_market_index ), - "drift_signer": get_drift_client_signer_public_key(self.program_id), - "user_token_account": self.get_associated_token_account_public_key( - spot_market_index - ), - "token_program": TOKEN_PROGRAM_ID, }, remaining_accounts=ra, ), ) - async def remove_insurance_fund_stake(self, spot_market_index: int): + async def remove_insurance_fund_stake( + self, spot_market_index: int, user_token_account: Pubkey = None + ): return await self.send_ixs( - await self.get_remove_insurance_fund_stake_ix(spot_market_index) + await self.get_remove_insurance_fund_stake_ix( + spot_market_index, user_token_account + ) ) - async def get_remove_insurance_fund_stake_ix(self, spot_market_index: int): + async def get_remove_insurance_fund_stake_ix( + self, spot_market_index: int, user_token_account: Pubkey = None + ): ra = self.get_remaining_accounts( writable_spot_market_indexes=[spot_market_index], ) + user_token_account = ( + user_token_account + if user_token_account is not None + else self.get_associated_token_account_public_key(spot_market_index) + ) + return self.program.instruction["remove_insurance_fund_stake"]( spot_market_index, ctx=Context( @@ -1711,29 +1718,35 @@ async def get_remove_insurance_fund_stake_ix(self, spot_market_index: int): self.program_id, spot_market_index ), "drift_signer": get_drift_client_signer_public_key(self.program_id), - "user_token_account": self.get_associated_token_account_public_key( - spot_market_index - ), + "user_token_account": user_token_account, "token_program": TOKEN_PROGRAM_ID, }, remaining_accounts=ra, ), ) - async def add_insurance_fund_stake(self, spot_market_index: int, amount: int): + async def add_insurance_fund_stake( + self, spot_market_index: int, amount: int, user_token_account: Pubkey = None + ): return await self.send_ixs( - await self.get_add_insurance_fund_stake_ix(spot_market_index, amount) + await self.get_add_insurance_fund_stake_ix( + spot_market_index, amount, user_token_account + ) ) async def get_add_insurance_fund_stake_ix( - self, - spot_market_index: int, - amount: int, + self, spot_market_index: int, amount: int, user_token_account: Pubkey = None ): remaining_accounts = self.get_remaining_accounts( writable_spot_market_indexes=[spot_market_index], ) + user_token_account = ( + user_token_account + if user_token_account is not None + else self.get_associated_token_account_public_key(spot_market_index) + ) + return self.program.instruction["add_insurance_fund_stake"]( spot_market_index, amount, @@ -1757,9 +1770,7 @@ async def get_add_insurance_fund_stake_ix( self.program_id, spot_market_index ), "drift_signer": get_drift_client_signer_public_key(self.program_id), - "user_token_account": self.get_associated_token_account_public_key( - spot_market_index - ), + "user_token_account": user_token_account, "token_program": TOKEN_PROGRAM_ID, }, remaining_accounts=remaining_accounts, diff --git a/tests/test.py b/tests/test.py index 4c6d6d0a..81632a7a 100644 --- a/tests/test.py +++ b/tests/test.py @@ -206,7 +206,6 @@ async def test_usdc_deposit( ): usdc_spot_market = await get_spot_market_account(drift_client.program, 0) assert usdc_spot_market.market_index == 0 - drift_client.spot_market_atas[0] = user_usdc_account.pubkey() await drift_client.deposit( USDC_AMOUNT, 0, user_usdc_account.pubkey(), user_initialized=True ) @@ -372,7 +371,9 @@ async def test_stake_if( if_acc = await get_if_stake_account(drift_client.program, drift_client.authority, 0) assert if_acc.market_index == 0 - await drift_client.add_insurance_fund_stake(0, 1 * QUOTE_PRECISION) + await drift_client.add_insurance_fund_stake( + 0, 1 * QUOTE_PRECISION, user_usdc_account.pubkey() + ) user_stats = await get_user_stats_account( drift_client.program, drift_client.authority @@ -381,7 +382,7 @@ async def test_stake_if( await drift_client.request_remove_insurance_fund_stake(0, 1 * QUOTE_PRECISION) - await drift_client.remove_insurance_fund_stake(0) + await drift_client.remove_insurance_fund_stake(0, user_usdc_account.pubkey()) user_stats = await get_user_stats_account( drift_client.program, drift_client.authority