diff --git a/src/driftpy/_types.py b/src/driftpy/_types.py deleted file mode 100644 index 531311e6..00000000 --- a/src/driftpy/_types.py +++ /dev/null @@ -1,475 +0,0 @@ -from driftpy.constants.numeric_constants import SPOT_RATE_PRECISION -from driftpy.types import OracleSource -from typing import Optional, Any -from dataclasses import dataclass -from sumtypes import constructor # type: ignore -from borsh_construct.enum import _rust_enum -from solana.publickey import PublicKey - - -@dataclass -class PriceDivergence: - mark_oracle_divergence_numerator: int - mark_oracle_divergence_denominator: int - - -@dataclass -class Validity: - slots_before_stale: int - confidence_interval_max_size: int - too_volatile_ratio: int - - -@dataclass -class OracleGuardRails: - price_divergence: PriceDivergence - validity: Validity - use_for_liquidations: bool - - -@dataclass -class DiscountTokenTier: - minimum_balance: int - discount_numerator: int - discount_denominator: int - - -@dataclass -class DiscountTokenTiers: - first_tier: DiscountTokenTier - second_tier: DiscountTokenTier - third_tier: DiscountTokenTier - fourth_tier: DiscountTokenTier - - -@dataclass -class ReferralDiscount: - referrer_reward_numerator: int - referrer_reward_denominator: int - referee_discount_numerator: int - referee_discount_denominator: int - - -@dataclass -class OrderFillerRewardStructure: - reward_numerator: int - reward_denominator: int - time_based_reward_lower_bound: int - - -@dataclass -class FeeStructure: - fee_numerator: int - fee_denominator: int - discount_token_tiers: DiscountTokenTiers - referral_discount: ReferralDiscount - maker_rebate_numerator: int - maker_rebate_denominator: int - filler_reward_structure: OrderFillerRewardStructure - - -@dataclass -class StateAccount: - admin: PublicKey - exchange_paused: bool - funding_paused: bool - admin_controls_prices: bool - insurance_vault: PublicKey - insurance_vault_authority: PublicKey - insurance_vault_nonce: int - margin_ratio_initial: int - margin_ratio_maintenance: int - margin_ratio_partial: int - partial_liquidation_close_percentage_numerator: int - partial_liquidation_close_percentage_denominator: int - partial_liquidation_penalty_percentage_numerator: int - partial_liquidation_penalty_percentage_denominator: int - full_liquidation_penalty_percentage_numerator: int - full_liquidation_penalty_percentage_denominator: int - partial_liquidation_liquidator_share_denominator: int - full_liquidation_liquidator_share_denominator: int - fee_structure: FeeStructure - whitelist_mint: PublicKey - discount_mint: PublicKey - oracle_guard_rails: OracleGuardRails - number_of_markets: int - number_of_banks: int - min_order_quote_asset_amount: int - order_auction_duration: int - # upgrade-ability - padding0: int - padding1: int - - -# --- - - -@_rust_enum -class OracleSource: - Pyth = constructor() - Switchboard = constructor() - QuoteAsset = constructor() - Pyth_1K = constructor() - Pyth_1M = constructor() - - -@_rust_enum -class DepositDirection: - DEPOSIT = constructor() - WITHDRAW = constructor() - - -@dataclass -class TradeDirection: - long: Optional[Any] - short: Optional[Any] - - -@_rust_enum -class OrderType: - MARKET = constructor() - LIMIT = constructor() - TRIGGER_MARKET = constructor() - TRIGGER_LIMIT = constructor() - ORACLE = constructor() - - -@_rust_enum -class OrderStatus: - INIT = constructor() - OPEN = constructor() - FILLED = constructor() - CANCELED = constructor() - - -@_rust_enum -class OrderDiscountTier: - NONE = constructor() - FIRST = constructor() - SECOND = constructor() - THIRD = constructor() - FOURTH = constructor() - - -@_rust_enum -class OrderTriggerCondition: - ABOVE = constructor() - BELOW = constructor() - - -@_rust_enum -class OrderAction: - PLACE = constructor() - FILL = constructor() - CANCEL = constructor() - - -@_rust_enum -class PositionDirection: - LONG = constructor() - SHORT = constructor() - - -@_rust_enum -class SwapDirection: - ADD = constructor() - REMOVE = constructor() - - -@_rust_enum -class AssetType: - QUOTE = constructor() - BASE = constructor() - - -@_rust_enum -class BankBalanceType: - DEPOSIT = constructor() - BORROW = constructor() - - -# --- - - -@dataclass -class Order: - status: OrderStatus - order_type: OrderType - ts: int - order_id: int - user_order_id: int - market_index: int - price: int - user_base_asset_amount: int - base_asset_amount: int - base_asset_amount_filled: int - quote_asset_amount: int - quote_asset_amount_filled: int - fee: int - direction: PositionDirection - reduce_only: bool - trigger_price: int - trigger_condition: OrderTriggerCondition - discount_tier: OrderDiscountTier - referrer: PublicKey - post_only: bool - immediate_or_cancel: bool - oracle_price_offset: int - - -@dataclass -class OrderParamsOptionalAccounts: - discount_token: bool = False - referrer: bool = False - - -@dataclass -class OrderParams: - # necessary - order_type: OrderType - direction: PositionDirection - market_index: int - base_asset_amount: int - # optional - user_order_id: int = 0 - price: int = 0 - reduce_only: bool = False - post_only: bool = False - immediate_or_cancel: bool = False - trigger_price: int = 0 - trigger_condition: OrderTriggerCondition = OrderTriggerCondition.ABOVE() - position_limit: int = 0 - oracle_price_offset: int = 0 - auction_duration: int = 0 - padding0: bool = 0 - padding1: bool = 0 - optional_accounts: OrderParamsOptionalAccounts = OrderParamsOptionalAccounts() - - -@dataclass -class MakerInfo: - maker: PublicKey - order: Order - - -@dataclass -class OrderFillerRewardStructure: - reward_numerator: int - reward_denominator: int - time_based_reward_lower_bound: int # minimum time filler reward - - -@dataclass -class MarketPosition: - market_index: int - base_asset_amount: int - quote_asset_amount: int - quote_entry_amount: int - last_cumulative_funding_rate: int - last_cumulative_repeg_rebate: int - last_funding_rate_ts: int - open_orders: int - unsettled_pnl: int - open_bids: int - open_asks: int - - # lp stuff - lp_shares: int - lp_base_asset_amount: int - lp_quote_asset_amount: int - last_cumulative_funding_payment_per_lp: int - last_cumulative_fee_per_lp: int - last_cumulative_base_asset_amount_with_amm_per_lp: int - last_lp_add_time: int - - # upgrade-ability - padding0: int - padding1: int - padding2: int - padding3: int - padding4: int - - # dw why this doesnt register :( - # def is_available(self): - # return self.base_asset_amount == 0 and self.open_orders == 0 and - # self.lp_shares == 0 - - -@dataclass -class UserFees: - total_fee_paid: int - total_fee_rebate: int - total_token_discount: int - total_referral_reward: int - total_referee_discount: int - - -@dataclass -class UserBankBalance: - bank_index: int - balance_type: BankBalanceType - balance: int - - -@dataclass -class User: - authority: PublicKey - user_id: int - name: list[int] - bank_balances: list[UserBankBalance] - fees: UserFees - next_order_id: int - positions: list[MarketPosition] - orders: list[Order] - - -@dataclass -class PoolBalance: - balance: int - - -@dataclass -class Bank: - bank_index: int - pubkey: PublicKey - oracle: PublicKey - oracle_source: OracleSource - mint: PublicKey - vault: PublicKey - vault_authority: PublicKey - vault_authority_nonce: int - decimals: int - optimal_utilization: int - optimal_borrow_rate: int - max_borrow_rate: int - deposit_balance: int - borrow_balance: int - cumulative_deposit_interest: int - cumulative_borrow_interest: int - last_updated: int - initial_asset_weight: int - maintenance_asset_weight: int - initial_liability_weight: int - maintenance_liability_weight: int - - -@dataclass -class AMM: - oracle: PublicKey - oracle_source: OracleSource = OracleSource.Pyth() - last_oracle_price: int = 0 - last_oracle_conf_pct: int = 0 - last_oracle_delay: int = 0 - last_oracle_normalised_price: int = 0 - last_oracle_price_twap: int = 0 - last_oracle_price_twap_ts: int = 0 - last_oracle_mark_spread_pct: int = 0 - - base_asset_reserve: int = 0 - quote_asset_reserve: int = 0 - sqrt_k: int = 0 - peg_multiplier: int = 0 - - terminal_quote_asset_reserve: int = 0 - base_asset_amount_with_amm: int = 0 - base_asset_amount_with_unsettled_lp: int = 0 - - base_asset_amount_long: int = 0 - base_asset_amount_short: int = 0 - - quote_asset_amount_long: int = 0 - quote_asset_amount_short: int = 0 - - # lp stuff - cumulative_funding_payment_per_lp: int = 0 - cumulative_fee_per_lp: int = 0 - cumulative_base_asset_amount_with_amm_per_lp: int = 0 - lp_cooldown_time: int = 0 - user_lp_shares: int = 0 - - # funding - last_funding_rate: int = 0 - last_funding_rate_ts: int = 0 - funding_period: int = 0 - cumulative_funding_rate_long: int = 0 - cumulative_funding_rate_short: int = 0 - cumulative_repeg_rebate_long: int = 0 - cumulative_repeg_rebate_short: int = 0 - - mark_std: int = 0 - last_mark_price_twap: int = 0 - last_mark_price_twap_ts: int = 0 - - # trade constraints - minimum_quote_asset_trade_size: int = 0 - base_asset_amount_step_size: int = 0 - - # market making - base_spread: int = 0 - long_spread: int = 0 - short_spread: int = 0 - max_spread: int = 0 - ask_base_asset_reserve: int = 0 - ask_quote_asset_reserve: int = 0 - bid_base_asset_reserve: int = 0 - bid_quote_asset_reserve: int = 0 - - last_bid_price_twap: int = 0 - last_ask_price_twap: int = 0 - - long_intensity_count: int = 0 - long_intensity_volume: int = 0 - short_intensity_count: int = 0 - short_intensity_volume: int = 0 - curve_update_intensity: int = 0 - - # fee tracking - total_fee: int = 0 - total_mm_fee: int = 0 - total_exchange_fee: int = 0 - total_fee_minus_distributions: int = 0 - total_fee_withdrawn: int = 0 - net_revenue_since_last_funding: int = 0 - fee_pool: int = 0 - last_update_slot: int = 0 - - padding0: int = 0 - padding1: int = 0 - padding2: int = 0 - padding3: int = 0 - - -@dataclass -class Market: - market_index: int - amm: AMM - pubkey: PublicKey = PublicKey(0) - initialized: bool = True - base_asset_amount_long: int = 0 - base_asset_amount_short: int = 0 - number_of_users: int = 0 - margin_ratio_initial: int = 1000 - margin_ratio_partial: int = 500 - margin_ratio_maintenance: int = 625 - next_fill_record_id: int = 0 - next_funding_rate_record_id: int = 0 - next_curve_record_id: int = 0 - pnl_pool: PoolBalance = PoolBalance(0) - unsettled_profit: int = 0 - unsettled_loss: int = 0 - - padding0: int = 0 - padding1: int = 0 - padding2: int = 0 - padding3: int = 0 - padding4: int = 0 - - -@dataclass -class SpotMarket: - mint: PublicKey # this - oracle: PublicKey = PublicKey([0] * PublicKey.LENGTH) # this - oracle_source: OracleSource = OracleSource.QUOTE_ASSET() - optimal_utilization: int = SPOT_RATE_PRECISION // 2 - optimal_rate: int = SPOT_RATE_PRECISION - max_rate: int = SPOT_RATE_PRECISION diff --git a/src/driftpy/accounts/cache/drift_client.py b/src/driftpy/accounts/cache/drift_client.py index 75905249..5ff16402 100644 --- a/src/driftpy/accounts/cache/drift_client.py +++ b/src/driftpy/accounts/cache/drift_client.py @@ -11,7 +11,12 @@ from driftpy.accounts.types import DriftClientAccountSubscriber, DataAndSlot from typing import Optional -from driftpy.types import PerpMarket, SpotMarket, OraclePriceData, State +from driftpy.types import ( + PerpMarketAccount, + SpotMarketAccount, + OraclePriceData, + StateAccount, +) class CachedDriftClientAccountSubscriber(DriftClientAccountSubscriber): @@ -70,19 +75,19 @@ async def update_cache(self): self.cache["oracle_price_data"] = oracle_data - async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]: + async def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]: await self.cache_if_needed() return self.cache["state"] async def get_perp_market_and_slot( self, market_index: int - ) -> Optional[DataAndSlot[PerpMarket]]: + ) -> Optional[DataAndSlot[PerpMarketAccount]]: await self.cache_if_needed() return self.cache["perp_markets"][market_index] async def get_spot_market_and_slot( self, market_index: int - ) -> Optional[DataAndSlot[SpotMarket]]: + ) -> Optional[DataAndSlot[SpotMarketAccount]]: await self.cache_if_needed() return self.cache["spot_markets"][market_index] diff --git a/src/driftpy/accounts/cache/user.py b/src/driftpy/accounts/cache/user.py index 068cdbe7..5134627f 100644 --- a/src/driftpy/accounts/cache/user.py +++ b/src/driftpy/accounts/cache/user.py @@ -6,7 +6,7 @@ from driftpy.accounts import get_user_account_and_slot from driftpy.accounts import UserAccountSubscriber, DataAndSlot -from driftpy.types import User +from driftpy.types import UserAccount class CachedUserAccountSubscriber(UserAccountSubscriber): @@ -28,7 +28,7 @@ async def update_cache(self): user_and_slot = await get_user_account_and_slot(self.program, self.user_pubkey) self.user_and_slot = user_and_slot - async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]: + async def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]: await self.cache_if_needed() return self.user_and_slot diff --git a/src/driftpy/accounts/get_accounts.py b/src/driftpy/accounts/get_accounts.py index 812fb687..fb5602a8 100644 --- a/src/driftpy/accounts/get_accounts.py +++ b/src/driftpy/accounts/get_accounts.py @@ -34,61 +34,63 @@ async def get_account_data_and_slot( return DataAndSlot(slot, decoded_data) -async def get_state_account_and_slot(program: Program) -> DataAndSlot[State]: +async def get_state_account_and_slot(program: Program) -> DataAndSlot[StateAccount]: state_public_key = get_state_public_key(program.program_id) return await get_account_data_and_slot(state_public_key, program) -async def get_state_account(program: Program) -> State: +async def get_state_account(program: Program) -> StateAccount: return (await get_state_account_and_slot(program)).data async def get_if_stake_account( program: Program, authority: Pubkey, spot_market_index: int -) -> InsuranceFundStake: +) -> InsuranceFundStakeAccount: if_stake_pk = get_insurance_fund_stake_public_key( program.program_id, authority, spot_market_index ) response = await program.account["InsuranceFundStake"].fetch(if_stake_pk) - return cast(InsuranceFundStake, response) + return cast(InsuranceFundStakeAccount, response) async def get_user_stats_account( program: Program, authority: Pubkey, -) -> UserStats: +) -> UserStatsAccount: user_stats_public_key = get_user_stats_account_public_key( program.program_id, authority, ) response = await program.account["UserStats"].fetch(user_stats_public_key) - return cast(UserStats, response) + return cast(UserStatsAccount, response) async def get_user_account_and_slot( program: Program, user_public_key: Pubkey, -) -> DataAndSlot[User]: +) -> DataAndSlot[UserAccount]: return await get_account_data_and_slot(user_public_key, program) async def get_user_account( program: Program, user_public_key: Pubkey, -) -> User: +) -> UserAccount: return (await get_user_account_and_slot(program, user_public_key)).data async def get_perp_market_account_and_slot( program: Program, market_index: int -) -> Optional[DataAndSlot[PerpMarket]]: +) -> Optional[DataAndSlot[PerpMarketAccount]]: perp_market_public_key = get_perp_market_public_key( program.program_id, market_index ) return await get_account_data_and_slot(perp_market_public_key, program) -async def get_perp_market_account(program: Program, market_index: int) -> PerpMarket: +async def get_perp_market_account( + program: Program, market_index: int +) -> PerpMarketAccount: return (await get_perp_market_account_and_slot(program, market_index)).data @@ -98,7 +100,7 @@ async def get_all_perp_market_accounts(program: Program) -> list[ProgramAccount] async def get_spot_market_account_and_slot( program: Program, spot_market_index: int -) -> DataAndSlot[SpotMarket]: +) -> DataAndSlot[SpotMarketAccount]: spot_market_public_key = get_spot_market_public_key( program.program_id, spot_market_index ) @@ -107,7 +109,7 @@ async def get_spot_market_account_and_slot( async def get_spot_market_account( program: Program, spot_market_index: int -) -> SpotMarket: +) -> SpotMarketAccount: return (await get_spot_market_account_and_slot(program, spot_market_index)).data diff --git a/src/driftpy/accounts/polling/drift_client.py b/src/driftpy/accounts/polling/drift_client.py index 64b86be0..a170cc78 100644 --- a/src/driftpy/accounts/polling/drift_client.py +++ b/src/driftpy/accounts/polling/drift_client.py @@ -10,7 +10,13 @@ from driftpy.accounts.bulk_account_loader import BulkAccountLoader from driftpy.accounts.oracle import get_oracle_decode_fn from driftpy.addresses import get_state_public_key -from driftpy.types import PerpMarket, SpotMarket, OraclePriceData, State, OracleSource +from driftpy.types import ( + PerpMarketAccount, + SpotMarketAccount, + OraclePriceData, + StateAccount, + OracleSource, +) class PollingDriftClientAccountSubscriber(DriftClientAccountSubscriber): @@ -24,7 +30,7 @@ def __init__( self.is_subscribed = False self.callbacks: dict[str, int] = {} - self.state: Optional[DataAndSlot[State]] = None + self.state: Optional[DataAndSlot[StateAccount]] = None self.perp_markets = {} self.spot_markets = {} self.oracle = {} @@ -143,17 +149,17 @@ def unsubscribe(self): ) self.callbacks.clear() - async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]: + async def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]: return self.state async def get_perp_market_and_slot( self, market_index: int - ) -> Optional[DataAndSlot[PerpMarket]]: + ) -> Optional[DataAndSlot[PerpMarketAccount]]: return self.perp_markets.get(market_index) async def get_spot_market_and_slot( self, market_index: int - ) -> Optional[DataAndSlot[SpotMarket]]: + ) -> Optional[DataAndSlot[SpotMarketAccount]]: return self.spot_markets.get(market_index) async def get_oracle_price_data_and_slot( diff --git a/src/driftpy/accounts/polling/user.py b/src/driftpy/accounts/polling/user.py index 39be17e7..eefc5ae8 100644 --- a/src/driftpy/accounts/polling/user.py +++ b/src/driftpy/accounts/polling/user.py @@ -6,7 +6,7 @@ from driftpy.accounts import UserAccountSubscriber, DataAndSlot from driftpy.accounts.bulk_account_loader import BulkAccountLoader -from driftpy.types import User +from driftpy.types import UserAccount class PollingUserAccountSubscriber(UserAccountSubscriber): @@ -19,7 +19,7 @@ def __init__( self.bulk_account_loader = bulk_account_loader self.program = program self.user_account_pubkey = user_account_pubkey - self.data_and_slot: Optional[DataAndSlot[User]] = None + self.data_and_slot: Optional[DataAndSlot[UserAccount]] = None self.decode = self.program.coder.accounts.decode self.callback_id = None @@ -53,7 +53,7 @@ def _account_loader_callback(self, buffer: bytes, slot: int): async def fetch(self): await self.bulk_account_loader.load() - def _update_data(self, new_data: Optional[DataAndSlot[User]]): + def _update_data(self, new_data: Optional[DataAndSlot[UserAccount]]): if new_data is None: return @@ -70,5 +70,5 @@ def unsubscribe(self): self.callback_id = None - async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]: + async def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]: return self.data_and_slot diff --git a/src/driftpy/accounts/types.py b/src/driftpy/accounts/types.py index 8ddd0c64..b5fcc67d 100644 --- a/src/driftpy/accounts/types.py +++ b/src/driftpy/accounts/types.py @@ -5,12 +5,12 @@ from solders.pubkey import Pubkey from driftpy.types import ( - PerpMarket, - SpotMarket, + PerpMarketAccount, + SpotMarketAccount, OracleSource, - User, + UserAccount, OraclePriceData, - State, + StateAccount, ) T = TypeVar("T") @@ -32,19 +32,19 @@ def unsubscribe(self): pass @abstractmethod - async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]: + async def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]: pass @abstractmethod async def get_perp_market_and_slot( self, market_index: int - ) -> Optional[DataAndSlot[PerpMarket]]: + ) -> Optional[DataAndSlot[PerpMarketAccount]]: pass @abstractmethod async def get_spot_market_and_slot( self, market_index: int - ) -> Optional[DataAndSlot[SpotMarket]]: + ) -> Optional[DataAndSlot[SpotMarketAccount]]: pass @abstractmethod @@ -64,5 +64,5 @@ def unsubscribe(self): pass @abstractmethod - async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]: + async def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]: pass diff --git a/src/driftpy/accounts/ws/drift_client.py b/src/driftpy/accounts/ws/drift_client.py index 3cec6b65..68908224 100644 --- a/src/driftpy/accounts/ws/drift_client.py +++ b/src/driftpy/accounts/ws/drift_client.py @@ -6,7 +6,12 @@ from typing import Optional from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber -from driftpy.types import PerpMarket, SpotMarket, OraclePriceData, State +from driftpy.types import ( + PerpMarketAccount, + SpotMarketAccount, + OraclePriceData, + StateAccount, +) from driftpy.addresses import * @@ -26,7 +31,7 @@ def __init__(self, program: Program, commitment: Commitment = "confirmed"): async def subscribe(self): state_public_key = get_state_public_key(self.program.program_id) - self.state_subscriber = WebsocketAccountSubscriber[State]( + self.state_subscriber = WebsocketAccountSubscriber[StateAccount]( state_public_key, self.program, self.commitment ) await self.state_subscriber.subscribe() @@ -44,7 +49,7 @@ async def subscribe_to_spot_market(self, market_index: int): spot_market_public_key = get_spot_market_public_key( self.program.program_id, market_index ) - spot_market_subscriber = WebsocketAccountSubscriber[SpotMarket]( + spot_market_subscriber = WebsocketAccountSubscriber[SpotMarketAccount]( spot_market_public_key, self.program, self.commitment ) await spot_market_subscriber.subscribe() @@ -60,7 +65,7 @@ async def subscribe_to_perp_market(self, market_index: int): perp_market_public_key = get_perp_market_public_key( self.program.program_id, market_index ) - perp_market_subscriber = WebsocketAccountSubscriber[PerpMarket]( + perp_market_subscriber = WebsocketAccountSubscriber[PerpMarketAccount]( perp_market_public_key, self.program, self.commitment ) await perp_market_subscriber.subscribe() @@ -87,17 +92,17 @@ async def subscribe_to_oracle(self, oracle: Pubkey, oracle_source: OracleSource) await oracle_subscriber.subscribe() self.oracle_subscribers[str(oracle)] = oracle_subscriber - async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]: + async def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]: return self.state_subscriber.data_and_slot async def get_perp_market_and_slot( self, market_index: int - ) -> Optional[DataAndSlot[PerpMarket]]: + ) -> Optional[DataAndSlot[PerpMarketAccount]]: return self.perp_market_subscribers[market_index].data_and_slot async def get_spot_market_and_slot( self, market_index: int - ) -> Optional[DataAndSlot[SpotMarket]]: + ) -> Optional[DataAndSlot[SpotMarketAccount]]: return self.spot_market_subscribers[market_index].data_and_slot async def get_oracle_price_data_and_slot( diff --git a/src/driftpy/accounts/ws/user.py b/src/driftpy/accounts/ws/user.py index 425958ed..5e63a173 100644 --- a/src/driftpy/accounts/ws/user.py +++ b/src/driftpy/accounts/ws/user.py @@ -1,14 +1,14 @@ from typing import Optional from driftpy.accounts import DataAndSlot -from driftpy.types import User +from driftpy.types import UserAccount from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber from driftpy.accounts.types import UserAccountSubscriber class WebsocketUserAccountSubscriber( - WebsocketAccountSubscriber[User], UserAccountSubscriber + WebsocketAccountSubscriber[UserAccount], UserAccountSubscriber ): - async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]: + async def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]: return self.data_and_slot diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index 7cc04dd4..6e2be3fe 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -25,7 +25,6 @@ QUOTE_SPOT_MARKET_INDEX, ) from driftpy.drift_user import DriftUser -from driftpy.sdk_types import * from driftpy.accounts import * from driftpy.constants.config import DriftEnv, DRIFT_PROGRAM_ID, configs @@ -153,7 +152,7 @@ def get_user(self, sub_account_id=0) -> DriftUser: return self.users[sub_account_id] - async def get_user_account(self, sub_account_id=0) -> User: + async def get_user_account(self, sub_account_id=0) -> UserAccount: return await self.get_user(sub_account_id).get_user() def switch_active_user(self, sub_account_id: int): @@ -165,17 +164,17 @@ def get_state_public_key(self): def get_user_stats_public_key(self): return get_user_stats_account_public_key(self.program_id, self.authority) - async def get_state(self) -> Optional[State]: + async def get_state(self) -> Optional[StateAccount]: state_and_slot = await self.account_subscriber.get_state_account_and_slot() return getattr(state_and_slot, "data", None) - async def get_perp_market(self, market_index: int) -> Optional[PerpMarket]: + async def get_perp_market(self, market_index: int) -> Optional[PerpMarketAccount]: perp_market_and_slot = await self.account_subscriber.get_perp_market_and_slot( market_index ) return getattr(perp_market_and_slot, "data", None) - async def get_spot_market(self, market_index: int) -> Optional[SpotMarket]: + async def get_spot_market(self, market_index: int) -> Optional[SpotMarketAccount]: spot_market_and_slot = await self.account_subscriber.get_spot_market_and_slot( market_index ) @@ -320,7 +319,7 @@ def get_initialize_user_instructions( async def get_remaining_accounts( self, - user_accounts: list[User] = (), + user_accounts: list[UserAccount] = (), writable_perp_market_indexes: list[int] = (), writable_spot_market_indexes: list[int] = (), readable_spot_market_indexes: list[int] = (), @@ -404,7 +403,7 @@ async def add_spot_market_to_remaining_account_maps( ) async def get_remaining_accounts_for_users( - self, user_accounts: list[User] + self, user_accounts: list[UserAccount] ) -> (dict[str, AccountMeta], dict[int, AccountMeta], dict[int, AccountMeta]): oracle_map = {} spot_market_map = {} diff --git a/src/driftpy/drift_user.py b/src/driftpy/drift_user.py index b1b739e2..4ef97128 100644 --- a/src/driftpy/drift_user.py +++ b/src/driftpy/drift_user.py @@ -52,25 +52,25 @@ def unsubscribe(self): self.account_subscriber.unsubscribe() async def get_spot_oracle_data( - self, spot_market: SpotMarket + self, spot_market: SpotMarketAccount ) -> Optional[OraclePriceData]: return await self.drift_client.get_oracle_price_data(spot_market.oracle) async def get_perp_oracle_data( - self, perp_market: PerpMarket + self, perp_market: PerpMarketAccount ) -> Optional[OraclePriceData]: return await self.drift_client.get_oracle_price_data(perp_market.amm.oracle) - async def get_state(self) -> State: + async def get_state(self) -> StateAccount: return await self.drift_client.get_state() - async def get_spot_market(self, market_index: int) -> SpotMarket: + async def get_spot_market(self, market_index: int) -> SpotMarketAccount: return await self.drift_client.get_spot_market(market_index) - async def get_perp_market(self, market_index: int) -> PerpMarket: + async def get_perp_market(self, market_index: int) -> PerpMarketAccount: return await self.drift_client.get_perp_market(market_index) - async def get_user(self) -> User: + async def get_user(self) -> UserAccount: return (await self.account_subscriber.get_user_account_and_slot()).data async def get_open_orders( @@ -79,7 +79,7 @@ async def get_open_orders( # market_index: int, # position_direction: PositionDirection ): - user: User = await self.get_user() + user: UserAccount = await self.get_user() return user.orders async def get_spot_market_liability( diff --git a/src/driftpy/math/amm.py b/src/driftpy/math/amm.py index ccb6be43..af9b572a 100644 --- a/src/driftpy/math/amm.py +++ b/src/driftpy/math/amm.py @@ -5,8 +5,7 @@ AMM_TIMES_PEG_TO_QUOTE_PRECISION_RATIO, QUOTE_PRECISION, ) -from driftpy.sdk_types import AssetType -from driftpy.types import PositionDirection, SwapDirection, AMM +from driftpy.types import PositionDirection, SwapDirection, AMM, AssetType def calculate_peg_from_target_price( diff --git a/src/driftpy/math/funding.py b/src/driftpy/math/funding.py index b1de3d96..afca4ed9 100644 --- a/src/driftpy/math/funding.py +++ b/src/driftpy/math/funding.py @@ -1,5 +1,5 @@ from driftpy.types import ( - PerpMarket, + PerpMarketAccount, ) from driftpy.constants.numeric_constants import ( @@ -14,7 +14,7 @@ ) -def calculate_long_short_funding(market: PerpMarket): +def calculate_long_short_funding(market: PerpMarketAccount): sym = calculate_symmetric_funding(market) capped = calculate_capped_funding(market) if market.base_asset_amount > 0: @@ -25,7 +25,7 @@ def calculate_long_short_funding(market: PerpMarket): return [sym, sym] -def calculate_capped_funding(market: PerpMarket): +def calculate_capped_funding(market: PerpMarketAccount): smaller_side = min( abs(market.amm.base_asset_amount_short), market.amm.base_asset_amount_long ) @@ -52,7 +52,7 @@ def calculate_capped_funding(market: PerpMarket): return capped_funding -def calculate_symmetric_funding(market: PerpMarket): +def calculate_symmetric_funding(market: PerpMarketAccount): next_funding = calculate_oracle_mark_spread_owed(market) next_funding /= market.amm.last_oracle_price_twap * 100 @@ -60,11 +60,11 @@ def calculate_symmetric_funding(market: PerpMarket): return next_funding -def calculate_oracle_mark_spread_owed(market: PerpMarket): +def calculate_oracle_mark_spread_owed(market: PerpMarketAccount): return (market.amm.last_mark_price_twap - market.amm.last_oracle_price_twap) / 24 -def calculate_funding_fee_pool(market: PerpMarket): +def calculate_funding_fee_pool(market: PerpMarketAccount): fee_pool = ( market.amm.total_fee_minus_distributions - market.amm.total_fee / 2 ) / QUOTE_PRECISION diff --git a/src/driftpy/math/margin.py b/src/driftpy/math/margin.py index 6463b24c..11f96120 100644 --- a/src/driftpy/math/margin.py +++ b/src/driftpy/math/margin.py @@ -33,7 +33,7 @@ class MarginCategory(Enum): def calculate_asset_weight( amount, - spot_market: SpotMarket, + spot_market: SpotMarketAccount, margin_category: MarginCategory, ): size_precision = 10**spot_market.decimals @@ -86,7 +86,9 @@ def calculate_size_premium_liability_weight( return max_liability_weight -def calculate_net_user_pnl(perp_market: PerpMarket, oracle_data: OraclePriceData): +def calculate_net_user_pnl( + perp_market: PerpMarketAccount, oracle_data: OraclePriceData +): net_user_position_value = ( perp_market.amm.base_asset_amount_with_amm * oracle_data.price @@ -101,7 +103,9 @@ def calculate_net_user_pnl(perp_market: PerpMarket, oracle_data: OraclePriceData def calculate_net_user_pnl_imbalance( - perp_market: PerpMarket, spot_market: SpotMarket, oracle_data: OraclePriceData + perp_market: PerpMarketAccount, + spot_market: SpotMarketAccount, + oracle_data: OraclePriceData, ): user_pnl = calculate_net_user_pnl(perp_market, oracle_data) @@ -114,8 +118,8 @@ def calculate_net_user_pnl_imbalance( def calculate_unrealized_asset_weight( - perp_market: PerpMarket, - spot_market: SpotMarket, + perp_market: PerpMarketAccount, + spot_market: SpotMarketAccount, unrealized_pnl: int, margin_category: MarginCategory, oracle_data: OraclePriceData, @@ -146,7 +150,7 @@ def calculate_unrealized_asset_weight( def calculate_liability_weight( - balance_amount: int, spot_market: SpotMarket, margin_category: MarginCategory + balance_amount: int, spot_market: SpotMarketAccount, margin_category: MarginCategory ) -> int: size_precision = 10**spot_market.decimals if size_precision > AMM_RESERVE_PRECISION: @@ -180,7 +184,10 @@ def calculate_liability_weight( def get_spot_asset_value( - amount: int, oracle_data, spot_market: SpotMarket, margin_category: MarginCategory + amount: int, + oracle_data, + spot_market: SpotMarketAccount, + margin_category: MarginCategory, ): asset_value = get_token_value(amount, spot_market.decimals, oracle_data) @@ -192,7 +199,7 @@ def get_spot_asset_value( def calculate_market_margin_ratio( - market: PerpMarket, size: int, margin_category: MarginCategory + market: PerpMarketAccount, size: int, margin_category: MarginCategory ) -> int: match margin_category: case MarginCategory.INITIAL: @@ -212,7 +219,7 @@ def calculate_market_margin_ratio( def get_spot_liability_value( token_amount: int, oracle_data: OraclePriceData, - spot_market: SpotMarket, + spot_market: SpotMarketAccount, margin_category: MarginCategory, liquidation_buffer: int = None, max_margin_ratio: int = None, diff --git a/src/driftpy/math/positions.py b/src/driftpy/math/positions.py index cbe63ea1..9c18f596 100644 --- a/src/driftpy/math/positions.py +++ b/src/driftpy/math/positions.py @@ -4,7 +4,7 @@ def get_worst_case_token_amounts( position: SpotPosition, - spot_market: SpotMarket, + spot_market: SpotMarketAccount, oracle_data, ): token_amount = get_signed_token_amount( @@ -35,7 +35,9 @@ def calculate_base_asset_value_with_oracle( ) -def calculate_position_funding_pnl(market: PerpMarket, perp_position: PerpPosition): +def calculate_position_funding_pnl( + market: PerpMarketAccount, perp_position: PerpPosition +): if perp_position.base_asset_amount == 0: return 0 @@ -57,7 +59,7 @@ def calculate_position_funding_pnl(market: PerpMarket, perp_position: PerpPositi def calculate_position_pnl_with_oracle( - market: PerpMarket, + market: PerpMarketAccount, perp_position: PerpPosition, oracle_data: OraclePriceData, with_funding=False, @@ -98,7 +100,9 @@ def is_available(position: PerpPosition): ) -def calculate_base_asset_value(market: PerpMarket, user_position: PerpPosition) -> int: +def calculate_base_asset_value( + market: PerpMarketAccount, user_position: PerpPosition +) -> int: if user_position.base_asset_amount == 0: return 0 @@ -135,7 +139,7 @@ def calculate_base_asset_value(market: PerpMarket, user_position: PerpPosition) def calculate_position_pnl( - market: PerpMarket, market_position: PerpPosition, with_funding=False + market: PerpMarketAccount, market_position: PerpPosition, with_funding=False ): pnl = 0.0 @@ -156,7 +160,9 @@ def calculate_position_pnl( return pnl -def calculate_position_funding_pnl(market: PerpMarket, market_position: PerpPosition): +def calculate_position_funding_pnl( + market: PerpMarketAccount, market_position: PerpPosition +): funding_pnl = 0.0 if market_position.base_asset_amount == 0: diff --git a/src/driftpy/math/spot_market.py b/src/driftpy/math/spot_market.py index 7587454c..8995a725 100644 --- a/src/driftpy/math/spot_market.py +++ b/src/driftpy/math/spot_market.py @@ -13,7 +13,7 @@ def get_signed_token_amount(amount, balance_type): def get_token_amount( - balance: int, spot_market: SpotMarket, balance_type: SpotBalanceType + balance: int, spot_market: SpotMarketAccount, balance_type: SpotBalanceType ) -> int: percision_decrease = 10 ** (19 - spot_market.decimals) diff --git a/src/driftpy/math/trade.py b/src/driftpy/math/trade.py index 5545f32b..e9d98645 100644 --- a/src/driftpy/math/trade.py +++ b/src/driftpy/math/trade.py @@ -20,14 +20,13 @@ AMM_TO_QUOTE_PRECISION_RATIO, ) -from driftpy.types import PositionDirection, PerpMarket, AMM -from driftpy.sdk_types import AssetType +from driftpy.types import PositionDirection, PerpMarketAccount, AMM, AssetType def calculate_trade_acquired_amounts( direction: PositionDirection, amount: int, - market: PerpMarket, + market: PerpMarketAccount, input_asset_type=AssetType, use_spread: boolean = True, ): @@ -73,7 +72,7 @@ def calculate_trade_acquired_amounts( def calculate_trade_slippage( direction: PositionDirection, amount: int, - market: PerpMarket, + market: PerpMarketAccount, input_asset_type: AssetType, use_spread: boolean = True, ): @@ -131,7 +130,7 @@ def calculate_trade_slippage( def calculate_target_price_trade( - market: PerpMarket, + market: PerpMarketAccount, target_price: float, output_asset_type: AssetType, use_spread: boolean = True, diff --git a/src/driftpy/math/user.py b/src/driftpy/math/user.py index becac84c..c05f05b3 100644 --- a/src/driftpy/math/user.py +++ b/src/driftpy/math/user.py @@ -1,7 +1,7 @@ from driftpy.types import ( - User, + UserAccount, PerpPosition, - PerpMarket, + PerpMarketAccount, ) from collections.abc import Mapping @@ -15,7 +15,7 @@ def calculate_unrealised_pnl( user_position: list[PerpPosition], - markets: Mapping[int, PerpMarket], + markets: Mapping[int, PerpMarketAccount], market_index: int = None, ) -> int: pnl = 0 @@ -30,7 +30,7 @@ def calculate_unrealised_pnl( def get_total_position_value( user_position: list[PerpPosition], - markets: Mapping[int, PerpMarket], + markets: Mapping[int, PerpMarketAccount], ): value = 0 for position in user_position: @@ -42,7 +42,7 @@ def get_total_position_value( def get_position_value( user_position: list[PerpPosition], - markets: Mapping[int, PerpMarket], + markets: Mapping[int, PerpMarketAccount], market_index: int, ): assert market_index is None or int(market_index) >= 0 @@ -55,12 +55,16 @@ def get_position_value( return value -def get_total_collateral(user_account: User, markets: Mapping[int, PerpMarket]): +def get_total_collateral( + user_account: UserAccount, markets: Mapping[int, PerpMarketAccount] +): collateral = user_account.collateral return collateral + calculate_unrealised_pnl(user_account.positions, markets) -def get_margin_ratio(user_account: User, markets: Mapping[int, PerpMarket]): +def get_margin_ratio( + user_account: UserAccount, markets: Mapping[int, PerpMarketAccount] +): tpv = get_total_position_value(user_account.positions, markets) if tpv > 0: return get_total_collateral(user_account, markets) / tpv @@ -68,20 +72,24 @@ def get_margin_ratio(user_account: User, markets: Mapping[int, PerpMarket]): return np.nan -def get_leverage(user_account: User, markets: Mapping[int, PerpMarket]): +def get_leverage(user_account: UserAccount, markets: Mapping[int, PerpMarketAccount]): return get_total_position_value( user_account.positions, markets ) / get_total_collateral(user_account, markets) -def get_free_collateral(user_account: User, markets: Mapping[int, PerpMarket]): +def get_free_collateral( + user_account: UserAccount, markets: Mapping[int, PerpMarketAccount] +): return get_total_collateral(user_account, markets) - ( get_margin_requirement(user_account.positions, markets, "initial") ) def get_margin_requirement( - user_position: list[PerpPosition], markets: Mapping[int, PerpMarket], kind: str + user_position: list[PerpPosition], + markets: Mapping[int, PerpMarketAccount], + kind: str, ): assert kind in ["initial", "partial", "maintenance"] @@ -101,14 +109,18 @@ def get_margin_requirement( return value -def can_be_liquidated(user_account: User, markets: Mapping[int, PerpMarket]): +def can_be_liquidated( + user_account: UserAccount, markets: Mapping[int, PerpMarketAccount] +): return get_total_collateral(user_account, markets) < ( get_margin_requirement("partial") ) def liquidation_price( - user_account: User, markets: Mapping[int, PerpMarket], market_index: int + user_account: UserAccount, + markets: Mapping[int, PerpMarketAccount], + market_index: int, ): # todo diff --git a/src/driftpy/setup/helpers.py b/src/driftpy/setup/helpers.py index b66402db..cc7ac2b9 100644 --- a/src/driftpy/setup/helpers.py +++ b/src/driftpy/setup/helpers.py @@ -25,7 +25,6 @@ ) from solana.transaction import Signature -from driftpy.sdk_types import AssetType from driftpy.types import * from driftpy.math.amm import calculate_amm_reserves_after_swap, calculate_price @@ -33,7 +32,7 @@ async def adjust_oracle_pretrade( baa: int, position_direction: PositionDirection, - market: PerpMarket, + market: PerpMarketAccount, oracle_program: Program, ): price = calculate_price( diff --git a/src/driftpy/types.py b/src/driftpy/types.py index cb4ecea5..d7c55e10 100644 --- a/src/driftpy/types.py +++ b/src/driftpy/types.py @@ -287,12 +287,6 @@ class UserStatus: REDUCE_ONLY = constructor() -@_rust_enum -class AssetType: - BASE = constructor() - QUOTE = constructor() - - @_rust_enum class OrderStatus: INIT = constructor() @@ -559,7 +553,7 @@ class Order: @dataclass -class PhoenixV1FulfillmentConfig: +class PhoenixV1FulfillmentConfigAccount: pubkey: Pubkey phoenix_program_id: Pubkey phoenix_log_authority: Pubkey @@ -573,7 +567,7 @@ class PhoenixV1FulfillmentConfig: @dataclass -class SerumV3FulfillmentConfig: +class SerumV3FulfillmentConfigAccount: pubkey: Pubkey serum_program_id: Pubkey serum_market: Pubkey @@ -601,7 +595,7 @@ class InsuranceClaim: @dataclass -class PerpMarket: +class PerpMarketAccount: pubkey: Pubkey amm: AMM pnl_pool: PoolBalance @@ -656,7 +650,7 @@ class InsuranceFund: @dataclass -class SpotMarket: +class SpotMarketAccount: pubkey: Pubkey oracle: Pubkey mint: Pubkey @@ -713,7 +707,7 @@ class SpotMarket: @dataclass -class State: +class StateAccount: admin: Pubkey whitelist_mint: Pubkey discount_mint: Pubkey @@ -759,7 +753,7 @@ class PerpPosition: @dataclass -class User: +class UserAccount: authority: Pubkey delegate: Pubkey name: list[int] @@ -800,7 +794,7 @@ class UserFees: @dataclass -class UserStats: +class UserStatsAccount: authority: Pubkey referrer: Pubkey fees: UserFees @@ -883,7 +877,7 @@ class SpotBankruptcyRecord: @dataclass -class InsuranceFundStake: +class InsuranceFundStakeAccount: authority: Pubkey if_shares: int last_withdraw_request_shares: int @@ -897,7 +891,7 @@ class InsuranceFundStake: @dataclass -class ProtocolIfSharesTransferConfig: +class ProtocolIfSharesTransferConfigAccount: whitelisted_signers: list[Pubkey] max_transfer_per_epoch: int current_epoch_transfer: int @@ -906,7 +900,7 @@ class ProtocolIfSharesTransferConfig: @dataclass -class ReferrerName: +class ReferrerNameAccount: authority: Pubkey user: Pubkey user_stats: Pubkey @@ -927,3 +921,15 @@ class OraclePriceData: class TxParams: compute_units: Optional[int] compute_units_price: Optional[int] + + +@_rust_enum +class AssetType: + QUOTE = constructor() + BASE = constructor() + + +@dataclass +class MakerInfo: + maker: Pubkey + order: Order diff --git a/tests/test.py b/tests/test.py index faf2958a..f45e8ca4 100644 --- a/tests/test.py +++ b/tests/test.py @@ -30,10 +30,10 @@ from driftpy.addresses import * from driftpy.types import ( - User, + UserAccount, PositionDirection, OracleSource, - PerpMarket, + PerpMarketAccount, OrderType, OrderParams, # SwapDirection, @@ -181,7 +181,7 @@ async def test_market( ): program = drift_client.program market_oracle_public_key = initialized_market - market: PerpMarket = await get_perp_market_account(program, 0) + market: PerpMarketAccount = await get_perp_market_account(program, 0) assert market.amm.oracle == market_oracle_public_key @@ -194,7 +194,7 @@ async def test_init_user( user_public_key = get_user_account_public_key( drift_client.program.program_id, drift_client.authority, 0 ) - user: User = await get_user_account(drift_client.program, user_public_key) + user: UserAccount = await get_user_account(drift_client.program, user_public_key) assert user.authority == drift_client.authority