Skip to content

Commit

Permalink
drift_client init mirrors ts sdk
Browse files Browse the repository at this point in the history
make init mirror ts sdk
  • Loading branch information
crispheaney authored Nov 26, 2023
2 parents 5025f2d + 99ee891 commit e76cf12
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 97 deletions.
26 changes: 0 additions & 26 deletions src/driftpy/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,32 +26,6 @@


class Admin(DriftClient):
@staticmethod
def from_config(
config: Config,
provider: Provider,
authority: Keypair = None,
admin: bool = False,
):
# read the idl
file = Path(str(driftpy.__path__[0]) + "/idl/drift.json")
with file.open() as f:
idl_dict = json.load(f)
idl = Idl.from_json(idl_dict)

# create the program
program = Program(
idl,
config.drift_client_program_id,
provider,
)

drift_client = Admin(program, authority)
drift_client.config = config
drift_client.idl = idl

return drift_client

async def initialize(
self,
usdc_mint: Pubkey,
Expand Down
15 changes: 7 additions & 8 deletions src/driftpy/constants/config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from typing import Literal

from driftpy.constants.banks import devnet_banks, mainnet_banks, Bank
from driftpy.constants.markets import devnet_markets, mainnet_markets, Market
from dataclasses import dataclass
from solders.pubkey import Pubkey

DriftEnv = Literal["devnet", "mainnet"]

DRIFT_PROGRAM_ID = Pubkey.from_string("dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH")


@dataclass
class Config:
env: str
env: DriftEnv
pyth_oracle_mapping_address: Pubkey
drift_client_program_id: Pubkey
usdc_mint_address: Pubkey
default_http: str
default_ws: str
Expand All @@ -22,9 +27,6 @@ class Config:
pyth_oracle_mapping_address=Pubkey.from_string(
"BmA9Z6FjioHJPpjT39QazZyhDRUdZy2ezwx4GiDdE2u2"
),
drift_client_program_id=Pubkey.from_string(
"dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH"
),
usdc_mint_address=Pubkey.from_string(
"8zGuJQqwhZafTah7Uc7Z4tXRnguqkn5KLFAP8oV6PHe2"
),
Expand All @@ -38,9 +40,6 @@ class Config:
pyth_oracle_mapping_address=Pubkey.from_string(
"AHtgzX45WTKfkPG53L6WYhGEXwQkN1BVknET3sVsLL8J"
),
drift_client_program_id=Pubkey.from_string(
"dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH"
),
usdc_mint_address=Pubkey.from_string(
"EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v"
),
Expand Down
96 changes: 35 additions & 61 deletions src/driftpy/drift_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from solders.pubkey import Pubkey
import json
from typing import Optional
from solders.keypair import Keypair
from solana.transaction import Transaction
from solders.transaction import VersionedTransaction
Expand All @@ -9,10 +7,11 @@
from solders.instruction import Instruction
from solders.system_program import ID
from solders.sysvar import RENT
from solana.rpc.async_api import AsyncClient
from solana.transaction import AccountMeta
from solders.compute_budget import set_compute_unit_limit, set_compute_unit_price
from spl.token.constants import TOKEN_PROGRAM_ID
from anchorpy import Program, Context, Idl, Provider
from anchorpy import Program, Context, Idl, Provider, Wallet
from struct import pack_into
from pathlib import Path

Expand All @@ -21,20 +20,15 @@
from driftpy.constants.numeric_constants import (
QUOTE_SPOT_MARKET_INDEX,
)
from driftpy.addresses import *
from driftpy.drift_user import DriftUser
from driftpy.sdk_types import *
from driftpy.types import *
from driftpy.accounts import *

from driftpy.constants.config import Config
from driftpy.constants.config import Config, DriftEnv, DRIFT_PROGRAM_ID

from typing import Union, Optional, List, Sequence
from driftpy.math.positions import is_available, is_spot_position_available

from driftpy.accounts import DriftClientAccountSubscriber
from driftpy.accounts.ws import WebsocketDriftClientAccountSubscriber

DEFAULT_USER_NAME = "Main Account"


Expand All @@ -45,8 +39,10 @@ class DriftClient:

def __init__(
self,
program: Program,
signer: Keypair = None,
connection: AsyncClient,
wallet: Union[Keypair, Wallet],
env: DriftEnv = "mainnet",
program_id: Optional[Pubkey] = DRIFT_PROGRAM_ID,
authority: Pubkey = None,
account_subscription: Optional[
AccountSubscriptionConfig
Expand All @@ -62,18 +58,27 @@ def __init__(
program (Program): Drift anchor program (see from_config on how to initialize it)
authority (Keypair, optional): Authority of all txs - if None will default to the Anchor Provider.Wallet Keypair.
"""
self.program = program
self.program_id = program.program_id
file = Path(str(driftpy.__path__[0]) + "/idl/drift.json")
with file.open() as f:
raw = file.read_text()
idl = Idl.from_json(raw)

provider = Provider(connection, wallet)
self.program_id = program_id
self.program = Program(
idl,
self.program_id,
provider,
)

if signer is None:
signer = program.provider.wallet.payer
if isinstance(wallet, Keypair):
wallet = Wallet(wallet)

if authority is None:
authority = signer.pubkey()
authority = wallet.public_key

self.signer = signer
self.wallet = wallet
self.authority = authority
self.signers = [self.signer]
self.usdc_ata = None
self.spot_market_atas = {}

Expand All @@ -99,39 +104,6 @@ def __init__(

self.tx_version = tx_version if tx_version is not None else Legacy

@staticmethod
def from_config(config: Config, provider: Provider, authority: Keypair = None):
"""Initializes the drift client object from a Config
Args:
config (Config): the config to initialize form
provider (Provider): anchor provider
authority (Keypair, optional): _description_. Defaults to None.
Returns:
DriftClient
: the drift client object
"""
# read the idl
file = Path(str(driftpy.__path__[0]) + "/idl/drift.json")
print(file)
with file.open() as f:
raw = file.read_text()
idl = Idl.from_json(raw)

# create the program
program = Program(
idl,
config.drift_client_program_id,
provider,
)

drift_client = DriftClient(program, authority)
drift_client.config = config
drift_client.idl = idl

return drift_client

async def subscribe(self):
await self.account_subscriber.subscribe()
for sub_account_id in self.sub_account_ids:
Expand Down Expand Up @@ -220,16 +192,18 @@ async def send_ixs(
tx = Transaction(
instructions=ixs,
recent_blockhash=latest_blockhash,
fee_payer=self.signer.pubkey(),
fee_payer=self.wallet.public_key,
)

tx.sign_partial(self.signer)
tx.sign_partial(self.wallet.payer)

if signers is not None:
[tx.sign_partial(signer) for signer in signers]
elif self.tx_version == 0:
msg = MessageV0.try_compile(self.signer.pubkey(), ixs, [], latest_blockhash)
tx = VersionedTransaction(msg, [self.signer])
msg = MessageV0.try_compile(
self.wallet.public_key, ixs, [], latest_blockhash
)
tx = VersionedTransaction(msg, [self.wallet.payer])
else:
raise NotImplementedError("unknown tx version", self.tx_version)

Expand Down Expand Up @@ -804,7 +778,7 @@ async def get_place_spot_order_ix(
accounts={
"state": self.get_state_public_key(),
"user": user_account_public_key,
"authority": self.signer.pubkey(),
"authority": self.wallet.public_key,
},
remaining_accounts=remaining_accounts,
),
Expand Down Expand Up @@ -836,7 +810,7 @@ async def get_place_spot_orders_ix(
accounts={
"state": self.get_state_public_key(),
"user": self.get_user_account_public_key(sub_account_id),
"authority": self.signer.pubkey(),
"authority": self.wallet.public_key,
},
remaining_accounts=remaining_accounts,
),
Expand All @@ -849,7 +823,7 @@ async def get_place_spot_orders_ix(
accounts={
"state": self.get_state_public_key(),
"user": user_account_public_key,
"authority": self.signer.pubkey(),
"authority": self.wallet.public_key,
},
remaining_accounts=remaining_accounts,
),
Expand Down Expand Up @@ -888,7 +862,7 @@ async def get_place_perp_order_ix(
accounts={
"state": self.get_state_public_key(),
"user": user_account_public_key,
"authority": self.signer.pubkey(),
"authority": self.wallet.public_key,
},
remaining_accounts=remaining_accounts,
),
Expand Down Expand Up @@ -916,7 +890,7 @@ async def get_place_perp_orders_ix(
accounts={
"state": self.get_state_public_key(),
"user": self.get_user_account_public_key(sub_account_id),
"authority": self.signer.pubkey(),
"authority": self.wallet.public_key,
},
remaining_accounts=remaining_accounts,
),
Expand All @@ -929,7 +903,7 @@ async def get_place_perp_orders_ix(
accounts={
"state": self.get_state_public_key(),
"user": user_account_public_key,
"authority": self.signer.pubkey(),
"authority": self.wallet.public_key,
},
remaining_accounts=remaining_accounts,
),
Expand Down
8 changes: 6 additions & 2 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,11 @@ def provider(program: Program) -> Provider:

@async_fixture(scope="session")
async def drift_client(program: Program, usdc_mint: Keypair) -> Admin:
admin = Admin(program, account_subscription=AccountSubscriptionConfig("cached"))
admin = Admin(
program.provider.connection,
program.provider.wallet,
account_subscription=AccountSubscriptionConfig("cached"),
)
await admin.initialize(usdc_mint.pubkey(), admin_controls_prices=True)
await admin.subscribe()
return admin
Expand Down Expand Up @@ -394,7 +398,7 @@ async def test_liq_perp(

liq, _ = await _airdrop_user(drift_client.program.provider)
liq_drift_client = DriftClient(
drift_client.program,
drift_client.program.provider.connection,
liq,
account_subscription=AccountSubscriptionConfig("cached"),
)
Expand Down

0 comments on commit e76cf12

Please sign in to comment.