Skip to content

Commit

Permalink
consolidate type files and add Account suffix to mirror ts
Browse files Browse the repository at this point in the history
  • Loading branch information
crispheaney committed Nov 26, 2023
1 parent 1c1daa5 commit a46a203
Show file tree
Hide file tree
Showing 21 changed files with 167 additions and 597 deletions.
475 changes: 0 additions & 475 deletions src/driftpy/_types.py

This file was deleted.

13 changes: 9 additions & 4 deletions src/driftpy/accounts/cache/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
from driftpy.accounts.types import DriftClientAccountSubscriber, DataAndSlot
from typing import Optional

from driftpy.types import PerpMarket, SpotMarket, OraclePriceData, State
from driftpy.types import (
PerpMarketAccount,
SpotMarketAccount,
OraclePriceData,
StateAccount,
)


class CachedDriftClientAccountSubscriber(DriftClientAccountSubscriber):
Expand Down Expand Up @@ -70,19 +75,19 @@ async def update_cache(self):

self.cache["oracle_price_data"] = oracle_data

async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]:
async def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]:
await self.cache_if_needed()
return self.cache["state"]

async def get_perp_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[PerpMarket]]:
) -> Optional[DataAndSlot[PerpMarketAccount]]:
await self.cache_if_needed()
return self.cache["perp_markets"][market_index]

async def get_spot_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[SpotMarket]]:
) -> Optional[DataAndSlot[SpotMarketAccount]]:
await self.cache_if_needed()
return self.cache["spot_markets"][market_index]

Expand Down
4 changes: 2 additions & 2 deletions src/driftpy/accounts/cache/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from driftpy.accounts import get_user_account_and_slot
from driftpy.accounts import UserAccountSubscriber, DataAndSlot
from driftpy.types import User
from driftpy.types import UserAccount


class CachedUserAccountSubscriber(UserAccountSubscriber):
Expand All @@ -28,7 +28,7 @@ async def update_cache(self):
user_and_slot = await get_user_account_and_slot(self.program, self.user_pubkey)
self.user_and_slot = user_and_slot

async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]:
async def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
await self.cache_if_needed()
return self.user_and_slot

Expand Down
26 changes: 14 additions & 12 deletions src/driftpy/accounts/get_accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,61 +34,63 @@ async def get_account_data_and_slot(
return DataAndSlot(slot, decoded_data)


async def get_state_account_and_slot(program: Program) -> DataAndSlot[State]:
async def get_state_account_and_slot(program: Program) -> DataAndSlot[StateAccount]:
state_public_key = get_state_public_key(program.program_id)
return await get_account_data_and_slot(state_public_key, program)


async def get_state_account(program: Program) -> State:
async def get_state_account(program: Program) -> StateAccount:
return (await get_state_account_and_slot(program)).data


async def get_if_stake_account(
program: Program, authority: Pubkey, spot_market_index: int
) -> InsuranceFundStake:
) -> InsuranceFundStakeAccount:
if_stake_pk = get_insurance_fund_stake_public_key(
program.program_id, authority, spot_market_index
)
response = await program.account["InsuranceFundStake"].fetch(if_stake_pk)
return cast(InsuranceFundStake, response)
return cast(InsuranceFundStakeAccount, response)


async def get_user_stats_account(
program: Program,
authority: Pubkey,
) -> UserStats:
) -> UserStatsAccount:
user_stats_public_key = get_user_stats_account_public_key(
program.program_id,
authority,
)
response = await program.account["UserStats"].fetch(user_stats_public_key)
return cast(UserStats, response)
return cast(UserStatsAccount, response)


async def get_user_account_and_slot(
program: Program,
user_public_key: Pubkey,
) -> DataAndSlot[User]:
) -> DataAndSlot[UserAccount]:
return await get_account_data_and_slot(user_public_key, program)


async def get_user_account(
program: Program,
user_public_key: Pubkey,
) -> User:
) -> UserAccount:
return (await get_user_account_and_slot(program, user_public_key)).data


async def get_perp_market_account_and_slot(
program: Program, market_index: int
) -> Optional[DataAndSlot[PerpMarket]]:
) -> Optional[DataAndSlot[PerpMarketAccount]]:
perp_market_public_key = get_perp_market_public_key(
program.program_id, market_index
)
return await get_account_data_and_slot(perp_market_public_key, program)


async def get_perp_market_account(program: Program, market_index: int) -> PerpMarket:
async def get_perp_market_account(
program: Program, market_index: int
) -> PerpMarketAccount:
return (await get_perp_market_account_and_slot(program, market_index)).data


Expand All @@ -98,7 +100,7 @@ async def get_all_perp_market_accounts(program: Program) -> list[ProgramAccount]

async def get_spot_market_account_and_slot(
program: Program, spot_market_index: int
) -> DataAndSlot[SpotMarket]:
) -> DataAndSlot[SpotMarketAccount]:
spot_market_public_key = get_spot_market_public_key(
program.program_id, spot_market_index
)
Expand All @@ -107,7 +109,7 @@ async def get_spot_market_account_and_slot(

async def get_spot_market_account(
program: Program, spot_market_index: int
) -> SpotMarket:
) -> SpotMarketAccount:
return (await get_spot_market_account_and_slot(program, spot_market_index)).data


Expand Down
16 changes: 11 additions & 5 deletions src/driftpy/accounts/polling/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@
from driftpy.accounts.bulk_account_loader import BulkAccountLoader
from driftpy.accounts.oracle import get_oracle_decode_fn
from driftpy.addresses import get_state_public_key
from driftpy.types import PerpMarket, SpotMarket, OraclePriceData, State, OracleSource
from driftpy.types import (
PerpMarketAccount,
SpotMarketAccount,
OraclePriceData,
StateAccount,
OracleSource,
)


class PollingDriftClientAccountSubscriber(DriftClientAccountSubscriber):
Expand All @@ -24,7 +30,7 @@ def __init__(
self.is_subscribed = False
self.callbacks: dict[str, int] = {}

self.state: Optional[DataAndSlot[State]] = None
self.state: Optional[DataAndSlot[StateAccount]] = None
self.perp_markets = {}
self.spot_markets = {}
self.oracle = {}
Expand Down Expand Up @@ -143,17 +149,17 @@ def unsubscribe(self):
)
self.callbacks.clear()

async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]:
async def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]:
return self.state

async def get_perp_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[PerpMarket]]:
) -> Optional[DataAndSlot[PerpMarketAccount]]:
return self.perp_markets.get(market_index)

async def get_spot_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[SpotMarket]]:
) -> Optional[DataAndSlot[SpotMarketAccount]]:
return self.spot_markets.get(market_index)

async def get_oracle_price_data_and_slot(
Expand Down
8 changes: 4 additions & 4 deletions src/driftpy/accounts/polling/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from driftpy.accounts import UserAccountSubscriber, DataAndSlot

from driftpy.accounts.bulk_account_loader import BulkAccountLoader
from driftpy.types import User
from driftpy.types import UserAccount


class PollingUserAccountSubscriber(UserAccountSubscriber):
Expand All @@ -19,7 +19,7 @@ def __init__(
self.bulk_account_loader = bulk_account_loader
self.program = program
self.user_account_pubkey = user_account_pubkey
self.data_and_slot: Optional[DataAndSlot[User]] = None
self.data_and_slot: Optional[DataAndSlot[UserAccount]] = None
self.decode = self.program.coder.accounts.decode
self.callback_id = None

Expand Down Expand Up @@ -53,7 +53,7 @@ def _account_loader_callback(self, buffer: bytes, slot: int):
async def fetch(self):
await self.bulk_account_loader.load()

def _update_data(self, new_data: Optional[DataAndSlot[User]]):
def _update_data(self, new_data: Optional[DataAndSlot[UserAccount]]):
if new_data is None:
return

Expand All @@ -70,5 +70,5 @@ def unsubscribe(self):

self.callback_id = None

async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]:
async def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
return self.data_and_slot
16 changes: 8 additions & 8 deletions src/driftpy/accounts/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from solders.pubkey import Pubkey

from driftpy.types import (
PerpMarket,
SpotMarket,
PerpMarketAccount,
SpotMarketAccount,
OracleSource,
User,
UserAccount,
OraclePriceData,
State,
StateAccount,
)

T = TypeVar("T")
Expand All @@ -32,19 +32,19 @@ def unsubscribe(self):
pass

@abstractmethod
async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]:
async def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]:
pass

@abstractmethod
async def get_perp_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[PerpMarket]]:
) -> Optional[DataAndSlot[PerpMarketAccount]]:
pass

@abstractmethod
async def get_spot_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[SpotMarket]]:
) -> Optional[DataAndSlot[SpotMarketAccount]]:
pass

@abstractmethod
Expand All @@ -64,5 +64,5 @@ def unsubscribe(self):
pass

@abstractmethod
async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]:
async def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
pass
19 changes: 12 additions & 7 deletions src/driftpy/accounts/ws/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from typing import Optional

from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber
from driftpy.types import PerpMarket, SpotMarket, OraclePriceData, State
from driftpy.types import (
PerpMarketAccount,
SpotMarketAccount,
OraclePriceData,
StateAccount,
)

from driftpy.addresses import *

Expand All @@ -26,7 +31,7 @@ def __init__(self, program: Program, commitment: Commitment = "confirmed"):

async def subscribe(self):
state_public_key = get_state_public_key(self.program.program_id)
self.state_subscriber = WebsocketAccountSubscriber[State](
self.state_subscriber = WebsocketAccountSubscriber[StateAccount](
state_public_key, self.program, self.commitment
)
await self.state_subscriber.subscribe()
Expand All @@ -44,7 +49,7 @@ async def subscribe_to_spot_market(self, market_index: int):
spot_market_public_key = get_spot_market_public_key(
self.program.program_id, market_index
)
spot_market_subscriber = WebsocketAccountSubscriber[SpotMarket](
spot_market_subscriber = WebsocketAccountSubscriber[SpotMarketAccount](
spot_market_public_key, self.program, self.commitment
)
await spot_market_subscriber.subscribe()
Expand All @@ -60,7 +65,7 @@ async def subscribe_to_perp_market(self, market_index: int):
perp_market_public_key = get_perp_market_public_key(
self.program.program_id, market_index
)
perp_market_subscriber = WebsocketAccountSubscriber[PerpMarket](
perp_market_subscriber = WebsocketAccountSubscriber[PerpMarketAccount](
perp_market_public_key, self.program, self.commitment
)
await perp_market_subscriber.subscribe()
Expand All @@ -87,17 +92,17 @@ async def subscribe_to_oracle(self, oracle: Pubkey, oracle_source: OracleSource)
await oracle_subscriber.subscribe()
self.oracle_subscribers[str(oracle)] = oracle_subscriber

async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]:
async def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]:
return self.state_subscriber.data_and_slot

async def get_perp_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[PerpMarket]]:
) -> Optional[DataAndSlot[PerpMarketAccount]]:
return self.perp_market_subscribers[market_index].data_and_slot

async def get_spot_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[SpotMarket]]:
) -> Optional[DataAndSlot[SpotMarketAccount]]:
return self.spot_market_subscribers[market_index].data_and_slot

async def get_oracle_price_data_and_slot(
Expand Down
6 changes: 3 additions & 3 deletions src/driftpy/accounts/ws/user.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import Optional

from driftpy.accounts import DataAndSlot
from driftpy.types import User
from driftpy.types import UserAccount

from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber
from driftpy.accounts.types import UserAccountSubscriber


class WebsocketUserAccountSubscriber(
WebsocketAccountSubscriber[User], UserAccountSubscriber
WebsocketAccountSubscriber[UserAccount], UserAccountSubscriber
):
async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]:
async def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
return self.data_and_slot
Loading

0 comments on commit a46a203

Please sign in to comment.