diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index 248df5f9..73910120 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -1,22 +1,25 @@ +import base64 import json import os import random import string +from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, cast -import anchorpy import requests -from anchorpy import Context, Idl, Program, Provider, Wallet +from anchorpy.program.context import Context +from anchorpy.program.core import Program +from anchorpy.provider import Provider, Wallet +from anchorpy_core.idl import Idl from deprecated import deprecated from solana.rpc.async_api import AsyncClient -from solana.rpc.commitment import Processed +from solana.rpc.commitment import Commitment, Processed from solana.rpc.types import TxOpts -from solana.transaction import AccountMeta from solders import system_program from solders.address_lookup_table_account import AddressLookupTableAccount from solders.compute_budget import set_compute_unit_limit, set_compute_unit_price -from solders.instruction import Instruction +from solders.instruction import AccountMeta, Instruction from solders.keypair import Keypair from solders.pubkey import Pubkey from solders.signature import Signature @@ -35,11 +38,33 @@ import driftpy from driftpy.account_subscription_config import AccountSubscriptionConfig -from driftpy.accounts import * -from driftpy.accounts.cache import CachedDriftClientAccountSubscriber -from driftpy.accounts.demo import DemoDriftClientAccountSubscriber +from driftpy.accounts import ( + DataAndSlot, + OracleInfo, + OraclePriceData, + PerpMarketAccount, + SpotMarketAccount, + StateAccount, + TxParams, + UserAccount, +) +from driftpy.accounts.cache.drift_client import CachedDriftClientAccountSubscriber +from driftpy.accounts.demo.drift_client import DemoDriftClientAccountSubscriber +from driftpy.accounts.get_accounts import get_perp_market_account from driftpy.address_lookup_table import get_address_lookup_table -from driftpy.addresses import get_sequencer_public_key_and_bump +from driftpy.addresses import ( + get_drift_client_signer_public_key, + get_insurance_fund_stake_public_key, + get_insurance_fund_vault_public_key, + get_protected_maker_mode_config_public_key, + get_sequencer_public_key_and_bump, + get_serum_signer_public_key, + get_spot_market_public_key, + get_spot_market_vault_public_key, + get_state_public_key, + get_user_account_public_key, + get_user_stats_account_public_key, +) from driftpy.constants import BASE_PRECISION, PRICE_PRECISION from driftpy.constants.config import ( DEVNET_SEQUENCER_PROGRAM_ID, @@ -59,6 +84,23 @@ from driftpy.name import encode_name from driftpy.tx.standard_tx_sender import StandardTxSender from driftpy.tx.types import TxSender, TxSigAndSlot +from driftpy.types import ( + MakerInfo, + MarketType, + ModifyOrderParams, + Order, + OrderParams, + OrderType, + PerpPosition, + PhoenixV1FulfillmentConfigAccount, + PositionDirection, + ReferrerInfo, + SequenceAccount, + SerumV3FulfillmentConfigAccount, + SpotPosition, + SwapReduceOnly, + is_variant, +) DEFAULT_USER_NAME = "Main Account" @@ -82,24 +124,22 @@ class DriftClient: def __init__( self, connection: AsyncClient, - wallet: Union[Keypair, Wallet], - env: DriftEnv = "mainnet", + wallet: Keypair | Wallet, + env: DriftEnv | None = "mainnet", opts: TxOpts = DEFAULT_TX_OPTIONS, - authority: Pubkey = None, - account_subscription: Optional[ - AccountSubscriptionConfig - ] = AccountSubscriptionConfig.default(), - perp_market_indexes: list[int] = None, - spot_market_indexes: list[int] = None, - oracle_infos: list[OracleInfo] = None, + authority: Pubkey | None = None, + account_subscription: AccountSubscriptionConfig = AccountSubscriptionConfig.default(), + perp_market_indexes: list[int] | None = None, + spot_market_indexes: list[int] | None = None, + oracle_infos: list[OracleInfo] | None = None, tx_params: Optional[TxParams] = None, tx_version: Optional[TransactionVersion] = None, - tx_sender: TxSender = None, + tx_sender: TxSender | None = None, active_sub_account_id: Optional[int] = None, sub_account_ids: Optional[list[int]] = None, market_lookup_table: Optional[Pubkey] = None, jito_params: Optional[JitoParams] = None, - tx_sender_blockhash_commitment: Optional[Commitment] = None, + tx_sender_blockhash_commitment: Commitment | None = None, enforce_tx_sequencing: bool = False, ): """Initializes the drift client object @@ -110,18 +150,15 @@ def __init__( """ self.connection = connection - file = Path(str(driftpy.__path__[0]) + "/idl/drift.json") - with file.open() as f: - raw = file.read_text() - idl = Idl.from_json(raw) + file = Path(str(next(iter(driftpy.__path__))) + "/idl/drift.json") + idl = Idl.from_json(file.read_text()) + + if isinstance(wallet, Keypair): + wallet = Wallet(wallet) provider = Provider(connection, wallet, opts) self.program_id = DRIFT_PROGRAM_ID - self.program = Program( - idl, - self.program_id, - provider, - ) + self.program = Program(idl, self.program_id, provider) if isinstance(wallet, Keypair): wallet = Wallet(wallet) @@ -140,8 +177,8 @@ def __init__( if sub_account_ids is not None else [self.active_sub_account_id] ) - self.users = {} - self.user_stats = {} + self.users: dict[int, DriftUser] = {} + self.user_stats: dict[Pubkey, DriftUserStats] = {} self.last_perp_market_seen_cache = {} self.last_spot_market_seen_cache = {} @@ -149,13 +186,18 @@ def __init__( self.account_subscriber = account_subscription.get_drift_client_subscriber( self.program, perp_market_indexes, spot_market_indexes, oracle_infos ) + if self.account_subscriber is None: + raise ValueError("No account subscriber found") + self.account_subscription_config = account_subscription - self.market_lookup_table = ( - market_lookup_table - if market_lookup_table is not None - else configs[env].market_lookup_table - ) + self.market_lookup_table = None + if env is not None: + self.market_lookup_table = ( + market_lookup_table + if market_lookup_table is not None + else configs[env].market_lookup_table + ) self.market_lookup_table_account: Optional[AddressLookupTableAccount] = None if tx_params is None: @@ -167,10 +209,10 @@ def __init__( self.enforce_tx_sequencing = enforce_tx_sequencing if self.enforce_tx_sequencing is True: - file = Path(str(driftpy.__path__[0]) + "/idl/sequence_enforcer.json") - with file.open() as f: - raw = file.read_text() - idl = Idl.from_json(raw) + file = Path( + str(next(iter(driftpy.__path__))) + "/idl/sequence_enforcer.json" + ) + idl = Idl.from_json(file.read_text()) provider = Provider(connection, wallet, opts) self.sequence_enforcer_pid = ( @@ -208,7 +250,7 @@ def __init__( blockhash_commitment=( tx_sender_blockhash_commitment if tx_sender_blockhash_commitment is not None - else "finalized" + else Commitment("finalized") ), ) if tx_sender is None @@ -3442,3 +3484,38 @@ async def load_sequence_info(self): self.sequence_bump_by_subaccount[subaccount] = bump self.sequence_initialized_by_subaccount[subaccount] = True self.sequence_address_by_subaccount[subaccount] = address + + async def update_user_protected_maker_orders( + self, + sub_account_id: int, + protected_orders: bool, + ): + return ( + await self.send_ixs( + [ + await self.get_update_user_protected_maker_orders_ix( + sub_account_id, protected_orders + ) + ] + ) + ).tx_sig + + async def get_update_user_protected_maker_orders_ix( + self, + sub_account_id: int, + protected_orders: bool, + ): + return self.program.instruction["update_user_protected_maker_orders"]( + sub_account_id, + protected_orders, + ctx=Context( + accounts={ + "state": self.get_state_public_key(), + "user": self.get_user_account_public_key(sub_account_id), + "authority": self.wallet.payer.pubkey(), + "protected_maker_mode_config": get_protected_maker_mode_config_public_key( + self.program_id + ), + } + ), + )