diff --git a/src/driftpy/admin.py b/src/driftpy/admin.py index 0811bbe2..52df5efd 100644 --- a/src/driftpy/admin.py +++ b/src/driftpy/admin.py @@ -26,32 +26,6 @@ class Admin(DriftClient): - @staticmethod - def from_config( - config: Config, - provider: Provider, - authority: Keypair = None, - admin: bool = False, - ): - # read the idl - file = Path(str(driftpy.__path__[0]) + "/idl/drift.json") - with file.open() as f: - idl_dict = json.load(f) - idl = Idl.from_json(idl_dict) - - # create the program - program = Program( - idl, - config.drift_client_program_id, - provider, - ) - - drift_client = Admin(program, authority) - drift_client.config = config - drift_client.idl = idl - - return drift_client - async def initialize( self, usdc_mint: Pubkey, diff --git a/src/driftpy/constants/config.py b/src/driftpy/constants/config.py index 0329373d..a19b1395 100644 --- a/src/driftpy/constants/config.py +++ b/src/driftpy/constants/config.py @@ -1,14 +1,19 @@ +from typing import Literal + from driftpy.constants.banks import devnet_banks, mainnet_banks, Bank from driftpy.constants.markets import devnet_markets, mainnet_markets, Market from dataclasses import dataclass from solders.pubkey import Pubkey +DriftEnv = Literal["devnet", "mainnet"] + +DRIFT_PROGRAM_ID = Pubkey.from_string("dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH") + @dataclass class Config: - env: str + env: DriftEnv pyth_oracle_mapping_address: Pubkey - drift_client_program_id: Pubkey usdc_mint_address: Pubkey default_http: str default_ws: str @@ -22,9 +27,6 @@ class Config: pyth_oracle_mapping_address=Pubkey.from_string( "BmA9Z6FjioHJPpjT39QazZyhDRUdZy2ezwx4GiDdE2u2" ), - drift_client_program_id=Pubkey.from_string( - "dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH" - ), usdc_mint_address=Pubkey.from_string( "8zGuJQqwhZafTah7Uc7Z4tXRnguqkn5KLFAP8oV6PHe2" ), @@ -38,9 +40,6 @@ class Config: pyth_oracle_mapping_address=Pubkey.from_string( "AHtgzX45WTKfkPG53L6WYhGEXwQkN1BVknET3sVsLL8J" ), - drift_client_program_id=Pubkey.from_string( - "dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH" - ), usdc_mint_address=Pubkey.from_string( "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v" ), diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index cf609ac9..971eca0b 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -1,6 +1,4 @@ from solders.pubkey import Pubkey -import json -from typing import Optional from solders.keypair import Keypair from solana.transaction import Transaction from solders.transaction import VersionedTransaction @@ -9,10 +7,11 @@ from solders.instruction import Instruction from solders.system_program import ID from solders.sysvar import RENT +from solana.rpc.async_api import AsyncClient from solana.transaction import AccountMeta from solders.compute_budget import set_compute_unit_limit, set_compute_unit_price from spl.token.constants import TOKEN_PROGRAM_ID -from anchorpy import Program, Context, Idl, Provider +from anchorpy import Program, Context, Idl, Provider, Wallet from struct import pack_into from pathlib import Path @@ -21,20 +20,15 @@ from driftpy.constants.numeric_constants import ( QUOTE_SPOT_MARKET_INDEX, ) -from driftpy.addresses import * from driftpy.drift_user import DriftUser from driftpy.sdk_types import * -from driftpy.types import * from driftpy.accounts import * -from driftpy.constants.config import Config +from driftpy.constants.config import Config, DriftEnv, DRIFT_PROGRAM_ID from typing import Union, Optional, List, Sequence from driftpy.math.positions import is_available, is_spot_position_available -from driftpy.accounts import DriftClientAccountSubscriber -from driftpy.accounts.ws import WebsocketDriftClientAccountSubscriber - DEFAULT_USER_NAME = "Main Account" @@ -45,8 +39,10 @@ class DriftClient: def __init__( self, - program: Program, - signer: Keypair = None, + connection: AsyncClient, + wallet: Union[Keypair, Wallet], + env: DriftEnv = "mainnet", + program_id: Optional[Pubkey] = DRIFT_PROGRAM_ID, authority: Pubkey = None, account_subscription: Optional[ AccountSubscriptionConfig @@ -62,18 +58,27 @@ def __init__( program (Program): Drift anchor program (see from_config on how to initialize it) authority (Keypair, optional): Authority of all txs - if None will default to the Anchor Provider.Wallet Keypair. """ - self.program = program - self.program_id = program.program_id + file = Path(str(driftpy.__path__[0]) + "/idl/drift.json") + with file.open() as f: + raw = file.read_text() + idl = Idl.from_json(raw) + + provider = Provider(connection, wallet) + self.program_id = program_id + self.program = Program( + idl, + self.program_id, + provider, + ) - if signer is None: - signer = program.provider.wallet.payer + if isinstance(wallet, Keypair): + wallet = Wallet(wallet) if authority is None: - authority = signer.pubkey() + authority = wallet.public_key - self.signer = signer + self.wallet = wallet self.authority = authority - self.signers = [self.signer] self.usdc_ata = None self.spot_market_atas = {} @@ -99,39 +104,6 @@ def __init__( self.tx_version = tx_version if tx_version is not None else Legacy - @staticmethod - def from_config(config: Config, provider: Provider, authority: Keypair = None): - """Initializes the drift client object from a Config - - Args: - config (Config): the config to initialize form - provider (Provider): anchor provider - authority (Keypair, optional): _description_. Defaults to None. - - Returns: - DriftClient - : the drift client object - """ - # read the idl - file = Path(str(driftpy.__path__[0]) + "/idl/drift.json") - print(file) - with file.open() as f: - raw = file.read_text() - idl = Idl.from_json(raw) - - # create the program - program = Program( - idl, - config.drift_client_program_id, - provider, - ) - - drift_client = DriftClient(program, authority) - drift_client.config = config - drift_client.idl = idl - - return drift_client - async def subscribe(self): await self.account_subscriber.subscribe() for sub_account_id in self.sub_account_ids: @@ -220,16 +192,18 @@ async def send_ixs( tx = Transaction( instructions=ixs, recent_blockhash=latest_blockhash, - fee_payer=self.signer.pubkey(), + fee_payer=self.wallet.public_key, ) - tx.sign_partial(self.signer) + tx.sign_partial(self.wallet.payer) if signers is not None: [tx.sign_partial(signer) for signer in signers] elif self.tx_version == 0: - msg = MessageV0.try_compile(self.signer.pubkey(), ixs, [], latest_blockhash) - tx = VersionedTransaction(msg, [self.signer]) + msg = MessageV0.try_compile( + self.wallet.public_key, ixs, [], latest_blockhash + ) + tx = VersionedTransaction(msg, [self.wallet.payer]) else: raise NotImplementedError("unknown tx version", self.tx_version) @@ -804,7 +778,7 @@ async def get_place_spot_order_ix( accounts={ "state": self.get_state_public_key(), "user": user_account_public_key, - "authority": self.signer.pubkey(), + "authority": self.wallet.public_key, }, remaining_accounts=remaining_accounts, ), @@ -836,7 +810,7 @@ async def get_place_spot_orders_ix( accounts={ "state": self.get_state_public_key(), "user": self.get_user_account_public_key(sub_account_id), - "authority": self.signer.pubkey(), + "authority": self.wallet.public_key, }, remaining_accounts=remaining_accounts, ), @@ -849,7 +823,7 @@ async def get_place_spot_orders_ix( accounts={ "state": self.get_state_public_key(), "user": user_account_public_key, - "authority": self.signer.pubkey(), + "authority": self.wallet.public_key, }, remaining_accounts=remaining_accounts, ), @@ -888,7 +862,7 @@ async def get_place_perp_order_ix( accounts={ "state": self.get_state_public_key(), "user": user_account_public_key, - "authority": self.signer.pubkey(), + "authority": self.wallet.public_key, }, remaining_accounts=remaining_accounts, ), @@ -916,7 +890,7 @@ async def get_place_perp_orders_ix( accounts={ "state": self.get_state_public_key(), "user": self.get_user_account_public_key(sub_account_id), - "authority": self.signer.pubkey(), + "authority": self.wallet.public_key, }, remaining_accounts=remaining_accounts, ), @@ -929,7 +903,7 @@ async def get_place_perp_orders_ix( accounts={ "state": self.get_state_public_key(), "user": user_account_public_key, - "authority": self.signer.pubkey(), + "authority": self.wallet.public_key, }, remaining_accounts=remaining_accounts, ), diff --git a/tests/test.py b/tests/test.py index e1aa263a..f8ec1abc 100644 --- a/tests/test.py +++ b/tests/test.py @@ -87,7 +87,11 @@ def provider(program: Program) -> Provider: @async_fixture(scope="session") async def drift_client(program: Program, usdc_mint: Keypair) -> Admin: - admin = Admin(program, account_subscription=AccountSubscriptionConfig("cached")) + admin = Admin( + program.provider.connection, + program.provider.wallet, + account_subscription=AccountSubscriptionConfig("cached"), + ) await admin.initialize(usdc_mint.pubkey(), admin_controls_prices=True) await admin.subscribe() return admin @@ -394,7 +398,7 @@ async def test_liq_perp( liq, _ = await _airdrop_user(drift_client.program.provider) liq_drift_client = DriftClient( - drift_client.program, + drift_client.program.provider.connection, liq, account_subscription=AccountSubscriptionConfig("cached"), )