Skip to content

Commit

Permalink
Update drift_client for typing
Browse files Browse the repository at this point in the history
  • Loading branch information
SinaKhalili committed Dec 24, 2024
1 parent 4a868e6 commit b34d51b
Showing 1 changed file with 119 additions and 42 deletions.
161 changes: 119 additions & 42 deletions src/driftpy/drift_client.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
import base64
import json
import os
import random
import string
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, cast

import anchorpy
import requests
from anchorpy import Context, Idl, Program, Provider, Wallet
from anchorpy.program.context import Context
from anchorpy.program.core import Program
from anchorpy.provider import Provider, Wallet
from anchorpy_core.idl import Idl
from deprecated import deprecated
from solana.rpc.async_api import AsyncClient
from solana.rpc.commitment import Processed
from solana.rpc.commitment import Commitment, Processed
from solana.rpc.types import TxOpts
from solana.transaction import AccountMeta
from solders import system_program
from solders.address_lookup_table_account import AddressLookupTableAccount
from solders.compute_budget import set_compute_unit_limit, set_compute_unit_price
from solders.instruction import Instruction
from solders.instruction import AccountMeta, Instruction
from solders.keypair import Keypair
from solders.pubkey import Pubkey
from solders.signature import Signature
Expand All @@ -35,11 +38,33 @@

import driftpy
from driftpy.account_subscription_config import AccountSubscriptionConfig
from driftpy.accounts import *
from driftpy.accounts.cache import CachedDriftClientAccountSubscriber
from driftpy.accounts.demo import DemoDriftClientAccountSubscriber
from driftpy.accounts import (
DataAndSlot,
OracleInfo,
OraclePriceData,
PerpMarketAccount,
SpotMarketAccount,
StateAccount,
TxParams,
UserAccount,
)
from driftpy.accounts.cache.drift_client import CachedDriftClientAccountSubscriber
from driftpy.accounts.demo.drift_client import DemoDriftClientAccountSubscriber
from driftpy.accounts.get_accounts import get_perp_market_account
from driftpy.address_lookup_table import get_address_lookup_table
from driftpy.addresses import get_sequencer_public_key_and_bump
from driftpy.addresses import (
get_drift_client_signer_public_key,
get_insurance_fund_stake_public_key,
get_insurance_fund_vault_public_key,
get_protected_maker_mode_config_public_key,
get_sequencer_public_key_and_bump,
get_serum_signer_public_key,
get_spot_market_public_key,
get_spot_market_vault_public_key,
get_state_public_key,
get_user_account_public_key,
get_user_stats_account_public_key,
)
from driftpy.constants import BASE_PRECISION, PRICE_PRECISION
from driftpy.constants.config import (
DEVNET_SEQUENCER_PROGRAM_ID,
Expand All @@ -59,6 +84,23 @@
from driftpy.name import encode_name
from driftpy.tx.standard_tx_sender import StandardTxSender
from driftpy.tx.types import TxSender, TxSigAndSlot
from driftpy.types import (
MakerInfo,
MarketType,
ModifyOrderParams,
Order,
OrderParams,
OrderType,
PerpPosition,
PhoenixV1FulfillmentConfigAccount,
PositionDirection,
ReferrerInfo,
SequenceAccount,
SerumV3FulfillmentConfigAccount,
SpotPosition,
SwapReduceOnly,
is_variant,
)

DEFAULT_USER_NAME = "Main Account"

Expand All @@ -82,24 +124,22 @@ class DriftClient:
def __init__(
self,
connection: AsyncClient,
wallet: Union[Keypair, Wallet],
env: DriftEnv = "mainnet",
wallet: Keypair | Wallet,
env: DriftEnv | None = "mainnet",
opts: TxOpts = DEFAULT_TX_OPTIONS,
authority: Pubkey = None,
account_subscription: Optional[
AccountSubscriptionConfig
] = AccountSubscriptionConfig.default(),
perp_market_indexes: list[int] = None,
spot_market_indexes: list[int] = None,
oracle_infos: list[OracleInfo] = None,
authority: Pubkey | None = None,
account_subscription: AccountSubscriptionConfig = AccountSubscriptionConfig.default(),
perp_market_indexes: list[int] | None = None,
spot_market_indexes: list[int] | None = None,
oracle_infos: list[OracleInfo] | None = None,
tx_params: Optional[TxParams] = None,
tx_version: Optional[TransactionVersion] = None,
tx_sender: TxSender = None,
tx_sender: TxSender | None = None,
active_sub_account_id: Optional[int] = None,
sub_account_ids: Optional[list[int]] = None,
market_lookup_table: Optional[Pubkey] = None,
jito_params: Optional[JitoParams] = None,
tx_sender_blockhash_commitment: Optional[Commitment] = None,
tx_sender_blockhash_commitment: Commitment | None = None,
enforce_tx_sequencing: bool = False,
):
"""Initializes the drift client object
Expand All @@ -110,18 +150,15 @@ def __init__(
"""
self.connection = connection

file = Path(str(driftpy.__path__[0]) + "/idl/drift.json")
with file.open() as f:
raw = file.read_text()
idl = Idl.from_json(raw)
file = Path(str(next(iter(driftpy.__path__))) + "/idl/drift.json")
idl = Idl.from_json(file.read_text())

if isinstance(wallet, Keypair):
wallet = Wallet(wallet)

provider = Provider(connection, wallet, opts)
self.program_id = DRIFT_PROGRAM_ID
self.program = Program(
idl,
self.program_id,
provider,
)
self.program = Program(idl, self.program_id, provider)

if isinstance(wallet, Keypair):
wallet = Wallet(wallet)
Expand All @@ -140,22 +177,27 @@ def __init__(
if sub_account_ids is not None
else [self.active_sub_account_id]
)
self.users = {}
self.user_stats = {}
self.users: dict[int, DriftUser] = {}
self.user_stats: dict[Pubkey, DriftUserStats] = {}

self.last_perp_market_seen_cache = {}
self.last_spot_market_seen_cache = {}

self.account_subscriber = account_subscription.get_drift_client_subscriber(
self.program, perp_market_indexes, spot_market_indexes, oracle_infos
)
if self.account_subscriber is None:
raise ValueError("No account subscriber found")

self.account_subscription_config = account_subscription

self.market_lookup_table = (
market_lookup_table
if market_lookup_table is not None
else configs[env].market_lookup_table
)
self.market_lookup_table = None
if env is not None:
self.market_lookup_table = (
market_lookup_table
if market_lookup_table is not None
else configs[env].market_lookup_table
)
self.market_lookup_table_account: Optional[AddressLookupTableAccount] = None

if tx_params is None:
Expand All @@ -167,10 +209,10 @@ def __init__(

self.enforce_tx_sequencing = enforce_tx_sequencing
if self.enforce_tx_sequencing is True:
file = Path(str(driftpy.__path__[0]) + "/idl/sequence_enforcer.json")
with file.open() as f:
raw = file.read_text()
idl = Idl.from_json(raw)
file = Path(
str(next(iter(driftpy.__path__))) + "/idl/sequence_enforcer.json"
)
idl = Idl.from_json(file.read_text())

provider = Provider(connection, wallet, opts)
self.sequence_enforcer_pid = (
Expand Down Expand Up @@ -208,7 +250,7 @@ def __init__(
blockhash_commitment=(
tx_sender_blockhash_commitment
if tx_sender_blockhash_commitment is not None
else "finalized"
else Commitment("finalized")
),
)
if tx_sender is None
Expand Down Expand Up @@ -3442,3 +3484,38 @@ async def load_sequence_info(self):
self.sequence_bump_by_subaccount[subaccount] = bump
self.sequence_initialized_by_subaccount[subaccount] = True
self.sequence_address_by_subaccount[subaccount] = address

async def update_user_protected_maker_orders(
self,
sub_account_id: int,
protected_orders: bool,
):
return (
await self.send_ixs(
[
await self.get_update_user_protected_maker_orders_ix(
sub_account_id, protected_orders
)
]
)
).tx_sig

async def get_update_user_protected_maker_orders_ix(
self,
sub_account_id: int,
protected_orders: bool,
):
return self.program.instruction["update_user_protected_maker_orders"](
sub_account_id,
protected_orders,
ctx=Context(
accounts={
"state": self.get_state_public_key(),
"user": self.get_user_account_public_key(sub_account_id),
"authority": self.wallet.payer.pubkey(),
"protected_maker_mode_config": get_protected_maker_mode_config_public_key(
self.program_id
),
}
),
)

0 comments on commit b34d51b

Please sign in to comment.