Skip to content

Commit

Permalink
fix if tests
Browse files Browse the repository at this point in the history
  • Loading branch information
crispheaney committed Nov 27, 2023
1 parent 018c536 commit 0f85813
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 23 deletions.
51 changes: 31 additions & 20 deletions src/driftpy/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1648,7 +1648,7 @@ async def cancel_request_remove_insurance_fund_stake(self, spot_market_index: in
)

async def get_cancel_request_remove_insurance_fund_stake_ix(
self, spot_market_index: int
self, spot_market_index: int, user_token_account: Pubkey = None
):
ra = self.get_remaining_accounts(
writable_spot_market_indexes=[spot_market_index]
Expand All @@ -1672,26 +1672,33 @@ async def get_cancel_request_remove_insurance_fund_stake_ix(
"insurance_fund_vault": get_insurance_fund_vault_public_key(
self.program_id, spot_market_index
),
"drift_signer": get_drift_client_signer_public_key(self.program_id),
"user_token_account": self.get_associated_token_account_public_key(
spot_market_index
),
"token_program": TOKEN_PROGRAM_ID,
},
remaining_accounts=ra,
),
)

async def remove_insurance_fund_stake(self, spot_market_index: int):
async def remove_insurance_fund_stake(
self, spot_market_index: int, user_token_account: Pubkey = None
):
return await self.send_ixs(
await self.get_remove_insurance_fund_stake_ix(spot_market_index)
await self.get_remove_insurance_fund_stake_ix(
spot_market_index, user_token_account
)
)

async def get_remove_insurance_fund_stake_ix(self, spot_market_index: int):
async def get_remove_insurance_fund_stake_ix(
self, spot_market_index: int, user_token_account: Pubkey = None
):
ra = self.get_remaining_accounts(
writable_spot_market_indexes=[spot_market_index],
)

user_token_account = (
user_token_account
if user_token_account is not None
else self.get_associated_token_account_public_key(spot_market_index)
)

return self.program.instruction["remove_insurance_fund_stake"](
spot_market_index,
ctx=Context(
Expand All @@ -1711,29 +1718,35 @@ async def get_remove_insurance_fund_stake_ix(self, spot_market_index: int):
self.program_id, spot_market_index
),
"drift_signer": get_drift_client_signer_public_key(self.program_id),
"user_token_account": self.get_associated_token_account_public_key(
spot_market_index
),
"user_token_account": user_token_account,
"token_program": TOKEN_PROGRAM_ID,
},
remaining_accounts=ra,
),
)

async def add_insurance_fund_stake(self, spot_market_index: int, amount: int):
async def add_insurance_fund_stake(
self, spot_market_index: int, amount: int, user_token_account: Pubkey = None
):
return await self.send_ixs(
await self.get_add_insurance_fund_stake_ix(spot_market_index, amount)
await self.get_add_insurance_fund_stake_ix(
spot_market_index, amount, user_token_account
)
)

async def get_add_insurance_fund_stake_ix(
self,
spot_market_index: int,
amount: int,
self, spot_market_index: int, amount: int, user_token_account: Pubkey = None
):
remaining_accounts = self.get_remaining_accounts(
writable_spot_market_indexes=[spot_market_index],
)

user_token_account = (
user_token_account
if user_token_account is not None
else self.get_associated_token_account_public_key(spot_market_index)
)

return self.program.instruction["add_insurance_fund_stake"](
spot_market_index,
amount,
Expand All @@ -1757,9 +1770,7 @@ async def get_add_insurance_fund_stake_ix(
self.program_id, spot_market_index
),
"drift_signer": get_drift_client_signer_public_key(self.program_id),
"user_token_account": self.get_associated_token_account_public_key(
spot_market_index
),
"user_token_account": user_token_account,
"token_program": TOKEN_PROGRAM_ID,
},
remaining_accounts=remaining_accounts,
Expand Down
7 changes: 4 additions & 3 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ async def test_usdc_deposit(
):
usdc_spot_market = await get_spot_market_account(drift_client.program, 0)
assert usdc_spot_market.market_index == 0
drift_client.spot_market_atas[0] = user_usdc_account.pubkey()
await drift_client.deposit(
USDC_AMOUNT, 0, user_usdc_account.pubkey(), user_initialized=True
)
Expand Down Expand Up @@ -372,7 +371,9 @@ async def test_stake_if(
if_acc = await get_if_stake_account(drift_client.program, drift_client.authority, 0)
assert if_acc.market_index == 0

await drift_client.add_insurance_fund_stake(0, 1 * QUOTE_PRECISION)
await drift_client.add_insurance_fund_stake(
0, 1 * QUOTE_PRECISION, user_usdc_account.pubkey()
)

user_stats = await get_user_stats_account(
drift_client.program, drift_client.authority
Expand All @@ -381,7 +382,7 @@ async def test_stake_if(

await drift_client.request_remove_insurance_fund_stake(0, 1 * QUOTE_PRECISION)

await drift_client.remove_insurance_fund_stake(0)
await drift_client.remove_insurance_fund_stake(0, user_usdc_account.pubkey())

user_stats = await get_user_stats_account(
drift_client.program, drift_client.authority
Expand Down

0 comments on commit 0f85813

Please sign in to comment.