Skip to content

Commit

Permalink
feat: custom account decoding (#51)
Browse files Browse the repository at this point in the history
* feat: custom user account decoding

* fix: add decode_user to ws usermap sub

* fix: polling sub test failure

* fix: change OrderTriggerCondition naming

* fix: rework cached subscription for drift client to avoid 413/429

* chore: update README

* feat: revert cached dc sub, add DemoDriftClientAccountSubscriber

* update readme

* feat: add get_markets_and_oracles

* chore: add comment lol

* fix: get_markets_and_oracles no gPA

* chore: update perp market oracle configs to match spot market

* feat: more explicit demo subscriber

* chore: update readme with demo example

* chore: better readme

* chore: update perp_market & spot_market configs

* chore: update mainnet USDC SpotMarketConfig
  • Loading branch information
soundsonacid authored Dec 13, 2023
1 parent b45e6fc commit 0517467
Show file tree
Hide file tree
Showing 23 changed files with 1,084 additions and 51 deletions.
34 changes: 34 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,40 @@ pip install driftpy

Note: requires Python >= 3.10.

## ⚠️ IMPORTANT ⚠️

**PLEASE**, do not use QuickNode free RPCs to subscribe to the Drift Client.

If you are using QuickNode, you *must* use `AccountSubscriptionConfig("demo")`, and you can only subscribe to 1 perp market and 1 spot market at a time.

Non-QuickNode free RPCs (including the public mainnet-beta url) can use `cached` as well.

Example setup for `AccountSubscriptionConfig("demo")`:

```
# This example will listen to perp markets 0 & 1 and spot market 0
# If you are listening to any perp markets, you must listen to spot market 0 or the SDK will break
perp_markets = [0, 1]
spot_market_oracle_infos, perp_market_oracle_infos, spot_market_indexes = get_markets_and_oracles(perp_markets = perp_markets)
oracle_infos = spot_market_oracle_infos + perp_market_oracle_infos
drift_client = DriftClient(
connection,
wallet,
"mainnet",
perp_market_indexes = perp_markets,
spot_market_indexes = spot_market_indexes,
oracle_infos = oracle_infos,
account_subscription = AccountSubscriptionConfig("demo"),
)
await drift_client.subscribe()
```
If you intend to use `AccountSubscriptionConfig("demo)`, you *must* call `get_markets_and_oracles` to get the information you need.

`get_markets_and_oracles` will return all the necessary `OracleInfo`s and `market_indexes` in order to use the SDK.

## SDK Examples

- `examples/` folder includes more examples of how to use the SDK including how to provide liquidity/become an lp, stake in the insurance fund, etc.
Expand Down
2 changes: 1 addition & 1 deletion examples/floating_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def main(
post_only=PostOnlyParams.TryPostOnly(),
immediate_or_cancel=False,
trigger_price=0,
trigger_condition=OrderTriggerCondition.ABOVE(),
trigger_condition=OrderTriggerCondition.Above(),
oracle_price_offset=0,
auction_duration=None,
max_ts=None,
Expand Down
26 changes: 23 additions & 3 deletions src/driftpy/account_subscription_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@
WebsocketDriftClientAccountSubscriber,
WebsocketUserAccountSubscriber,
)
from driftpy.accounts.demo import (
DemoDriftClientAccountSubscriber,
DemoUserAccountSubscriber
)
from driftpy.types import OracleInfo


class AccountSubscriptionConfig:
@staticmethod
def default():
return AccountSubscriptionConfig("websocket")

def __init__(
self,
type: Literal["polling", "websocket", "cached"],
type: Literal["polling", "websocket", "cached", "demo"],
bulk_account_loader: Optional[BulkAccountLoader] = None,
commitment: Commitment = None,
):
Expand Down Expand Up @@ -83,7 +86,20 @@ def get_drift_client_subscriber(
self.commitment,
)
case "cached":
return CachedDriftClientAccountSubscriber(program, self.commitment)
return CachedDriftClientAccountSubscriber(
program,
self.commitment
)
case "demo":
if perp_market_indexes == [] or spot_market_indexes == [] or oracle_infos == []:
raise ValueError("spot_market_indexes / perp_market_indexes / oracle_infos all must be provided with demo config")
return DemoDriftClientAccountSubscriber(
program,
perp_market_indexes,
spot_market_indexes,
oracle_infos,
self.commitment
)

def get_user_client_subscriber(self, program: Program, user_pubkey: Pubkey):
match self.type:
Expand All @@ -99,3 +115,7 @@ def get_user_client_subscriber(self, program: Program, user_pubkey: Pubkey):
return CachedUserAccountSubscriber(
user_pubkey, program, self.commitment
)
case "demo":
return DemoUserAccountSubscriber(
user_pubkey, program, self.commitment
)
2 changes: 1 addition & 1 deletion src/driftpy/accounts/cache/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,4 @@ def get_oracle_price_data_and_slot(
return self.cache["oracle_price_data"][str(oracle)]

def unsubscribe(self):
self.cache = None
self.cache = None
2 changes: 2 additions & 0 deletions src/driftpy/accounts/demo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .drift_client import *
from .user import *
93 changes: 93 additions & 0 deletions src/driftpy/accounts/demo/drift_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import asyncio
from anchorpy import Program
from solders.pubkey import Pubkey
from solana.rpc.commitment import Commitment

from driftpy.accounts import (
get_state_account_and_slot,
)
from driftpy.accounts.get_accounts import get_spot_market_account_and_slot, get_perp_market_account_and_slot
from driftpy.accounts.oracle import oracle_ai_to_oracle_price_data
from driftpy.constants.perp_markets import devnet_perp_market_configs, mainnet_perp_market_configs
from driftpy.constants.spot_markets import devnet_spot_market_configs, mainnet_spot_market_configs
from driftpy.accounts.types import DriftClientAccountSubscriber, DataAndSlot
from typing import Optional
from driftpy.types import (
PerpMarketAccount,
SpotMarketAccount,
OraclePriceData,
StateAccount,
)


class DemoDriftClientAccountSubscriber(DriftClientAccountSubscriber):
def __init__(self, program: Program, perp_market_indexes, spot_market_indexes, oracle_infos, commitment: Commitment = "confirmed"):
self.program = program
self.perp_market_indexes = perp_market_indexes
self.spot_market_indexes = spot_market_indexes
self.oracle_infos = oracle_infos
self.commitment = commitment
self.cache = None

async def subscribe(self):
await self.update_cache()

async def update_cache(self):
if self.cache is None:
self.cache = {}

state_and_slot = await get_state_account_and_slot(self.program)
self.cache["state"] = state_and_slot

oracle_data: dict[str, DataAndSlot[OraclePriceData]] = {}
spot_markets: list[DataAndSlot[SpotMarketAccount]] = []
perp_markets: list[DataAndSlot[PerpMarketAccount]] = []

spot_market_indexes = sorted(self.spot_market_indexes)
perp_market_indexes = sorted(self.perp_market_indexes)

for index in spot_market_indexes:
spot_market_and_slot = await get_spot_market_account_and_slot(self.program, index)
spot_markets.append(spot_market_and_slot)

for index in perp_market_indexes:
perp_market_and_slot = await get_perp_market_account_and_slot(self.program, index)
perp_markets.append(perp_market_and_slot)

oracle_pubkeys = [oracle.pubkey for oracle in self.oracle_infos]

oracle_accounts = await self.program.provider.connection.get_multiple_accounts(oracle_pubkeys)

for i, oracle_ai in enumerate(oracle_accounts.value):
if oracle_ai.owner == Pubkey.from_string("NativeLoader1111111111111111111111111111111"):
continue
oracle_price_data_and_slot = oracle_ai_to_oracle_price_data(oracle_ai, self.oracle_infos[i].source)
oracle_data[str(self.oracle_infos[i].pubkey)] = oracle_price_data_and_slot

self.cache["spot_markets"] = spot_markets
self.cache["perp_markets"] = perp_markets
self.cache["oracle_price_data"] = oracle_data

async def fetch(self):
await self.update_cache()

def get_state_account_and_slot(self) -> Optional[DataAndSlot[StateAccount]]:
return self.cache["state"]

def get_perp_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[PerpMarketAccount]]:
return self.cache["perp_markets"][market_index]

def get_spot_market_and_slot(
self, market_index: int
) -> Optional[DataAndSlot[SpotMarketAccount]]:
return self.cache["spot_markets"][market_index]

def get_oracle_price_data_and_slot(
self, oracle: Pubkey
) -> Optional[DataAndSlot[OraclePriceData]]:
return self.cache["oracle_price_data"][str(oracle)]

def unsubscribe(self):
self.cache = None
38 changes: 38 additions & 0 deletions src/driftpy/accounts/demo/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Optional

from anchorpy import Program
from solders.pubkey import Pubkey
from solana.rpc.commitment import Commitment

from driftpy.accounts import get_user_account_and_slot
from driftpy.accounts import UserAccountSubscriber, DataAndSlot
from driftpy.types import UserAccount


class DemoUserAccountSubscriber(UserAccountSubscriber):
def __init__(
self,
user_pubkey: Pubkey,
program: Program,
commitment: Commitment = "confirmed",
):
self.program = program
self.commitment = commitment
self.user_pubkey = user_pubkey
self.user_and_slot = None

async def subscribe(self):
await self.update_cache()

async def update_cache(self):
user_and_slot = await get_user_account_and_slot(self.program, self.user_pubkey)
self.user_and_slot = user_and_slot

async def fetch(self):
await self.update_cache()

def get_user_account_and_slot(self) -> Optional[DataAndSlot[UserAccount]]:
return self.user_and_slot

def unsubscribe(self):
self.user_and_slot = None
12 changes: 12 additions & 0 deletions src/driftpy/accounts/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from driftpy.types import OracleSource, OraclePriceData, is_variant

from solders.pubkey import Pubkey
from solders.account import Account
from pythclient.pythaccounts import PythPriceInfo, _ACCOUNT_HEADER_BYTES, EmaType
from solana.rpc.async_api import AsyncClient
import struct
Expand Down Expand Up @@ -33,6 +34,17 @@ async def get_oracle_price_data_and_slot(
else:
raise NotImplementedError("Unsupported Oracle Source", str(oracle_source))

def oracle_ai_to_oracle_price_data(oracle_ai: Account, oracle_source=OracleSource.Pyth()) -> DataAndSlot[OraclePriceData]:
if "Pyth" in str(oracle_source):
oracle_price_data = decode_pyth_price_info(oracle_ai.data, oracle_source)

return DataAndSlot(oracle_price_data.slot, oracle_price_data)
elif is_variant(oracle_source, "QuoteAsset"):
return DataAndSlot(
data=OraclePriceData(PRICE_PRECISION, 0, 1, 1, 0, True), slot=0
)
else:
raise NotImplementedError("Unsupported Oracle Source", str(oracle_source))

def decode_pyth_price_info(
buffer: bytes,
Expand Down
1 change: 1 addition & 0 deletions src/driftpy/accounts/ws/account_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
self.decode = (
decode if decode is not None else self.program.coder.accounts.decode
)
self.ws = None

async def subscribe(self):
if self.data_and_slot is None:
Expand Down
1 change: 0 additions & 1 deletion src/driftpy/accounts/ws/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from driftpy.accounts.oracle import get_oracle_decode_fn


class WebsocketDriftClientAccountSubscriber(DriftClientAccountSubscriber):
def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/driftpy/accounts/ws/multi_account_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
decode if decode is not None else self.program.coder.accounts.decode
)
self.subscribed_accounts: Dict[Pubkey, DataAndSlot[T]] = {}
self.ws = None

async def subscribe(self):
self.task = asyncio.create_task(self.subscribe_ws())
Expand Down
42 changes: 39 additions & 3 deletions src/driftpy/constants/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal
from typing import Literal, Optional, Sequence, Union

from driftpy.constants.spot_markets import (
devnet_spot_market_configs,
Expand All @@ -15,7 +15,7 @@

from anchorpy import Program

from driftpy.types import OracleInfo
from driftpy.types import OracleInfo, OracleSource, SpotMarketAccount

DriftEnv = Literal["devnet", "mainnet"]

Expand Down Expand Up @@ -69,7 +69,6 @@ class Config:
),
}


async def find_all_market_and_oracles(
program: Program,
) -> (list[int], list[int], list[OracleInfo]):
Expand All @@ -92,3 +91,40 @@ async def find_all_market_and_oracles(
oracle_infos[str(oracle)] = OracleInfo(oracle, oracle_source)

return perp_market_indexes, spot_market_indexes, oracle_infos.values()

def find_market_config_by_index(
market_configs: list[Union[SpotMarketConfig, PerpMarketConfig]],
market_index: int
) -> Optional[Union[SpotMarketConfig, PerpMarketConfig]]:
for config in market_configs:
if hasattr(config, 'market_index') and config.market_index == market_index:
return config
return None


def get_markets_and_oracles(
env: DriftEnv = "mainnet",
perp_markets: Optional[Sequence[int]] = None,
spot_markets: Optional[Sequence[int]] = None,
):
config = configs[env]
spot_market_oracle_infos = []
perp_market_oracle_infos = []
spot_market_indexes = []

if perp_markets is None and spot_markets is None:
raise ValueError("no indexes provided")

if spot_markets is not None:
for spot_market_index in spot_markets:
market_config = find_market_config_by_index(config.spot_markets, spot_market_index)
spot_market_oracle_infos.append(OracleInfo(market_config.oracle, market_config.oracle_source))

if perp_markets is not None:
spot_market_indexes.append(0)
spot_market_oracle_infos.append(OracleInfo(config.spot_markets[0].oracle, config.spot_markets[0].oracle_source))
for perp_market_index in perp_markets:
market_config = find_market_config_by_index(config.perp_markets, perp_market_index)
perp_market_oracle_infos.append(OracleInfo(market_config.oracle, market_config.oracle_source))

return spot_market_oracle_infos, perp_market_oracle_infos, spot_market_indexes
Loading

0 comments on commit 0517467

Please sign in to comment.