From 34300b660890e631a7a44ebad1b9c561fd3506a3 Mon Sep 17 00:00:00 2001 From: Chris Heaney Date: Mon, 27 Nov 2023 08:18:17 -0500 Subject: [PATCH] add place_orders --- examples/limit_order_grid.py | 58 ++++---------- src/driftpy/drift_client.py | 145 +++++++++++++---------------------- tests/test.py | 4 +- 3 files changed, 72 insertions(+), 135 deletions(-) diff --git a/examples/limit_order_grid.py b/examples/limit_order_grid.py index a13b79e1..0144369d 100644 --- a/examples/limit_order_grid.py +++ b/examples/limit_order_grid.py @@ -202,62 +202,36 @@ async def main( len(ask_prices) ) - default_order_params = OrderParams( - order_type=OrderType.LIMIT(), - market_type=market_type, - direction=PositionDirection.LONG(), - user_order_id=0, - base_asset_amount=0, - price=0, - market_index=market_index, - reduce_only=False, - post_only=PostOnlyParams.TRY_POST_ONLY() - if not taker - else PostOnlyParams.NONE(), - immediate_or_cancel=False, - trigger_price=0, - trigger_condition=OrderTriggerCondition.ABOVE(), - oracle_price_offset=0, - auction_duration=None, - max_ts=None, - auction_start_price=None, - auction_end_price=None, - ) order_params = [] for x in bid_prices: - bid_order_params = copy.deepcopy(default_order_params) - bid_order_params.direction = PositionDirection.LONG() - bid_order_params.base_asset_amount = int( - base_asset_amount_per_bid * BASE_PRECISION + bid_order_params = OrderParams( + order_type=OrderType.LIMIT(), + market_index=market_index, + market_type=market_type, + direction=PositionDirection.LONG(), + base_asset_amount=int(base_asset_amount_per_bid * BASE_PRECISION), + price=int(x * PRICE_PRECISION), ) - bid_order_params.price = int(x * PRICE_PRECISION) if bid_order_params.base_asset_amount > 0: order_params.append(bid_order_params) for x in ask_prices: - ask_order_params = copy.deepcopy(default_order_params) - ask_order_params.base_asset_amount = int( - base_asset_amount_per_ask * BASE_PRECISION + ask_order_params = OrderParams( + order_type=OrderType.LIMIT(), + market_index=market_index, + market_type=market_type, + direction=PositionDirection.SHORT(), + base_asset_amount=int(base_asset_amount_per_ask * BASE_PRECISION), + price=int(x * PRICE_PRECISION), ) - ask_order_params.direction = PositionDirection.SHORT() - ask_order_params.price = int(x * PRICE_PRECISION) if ask_order_params.base_asset_amount > 0: order_params.append(ask_order_params) # print(order_params) # order_print([bid_order_params, ask_order_params], market_name) - perp_orders_ix = [] - spot_orders_ix = [] - if is_perp: - perp_orders_ix = await drift_acct.get_place_perp_orders_ix( - order_params, subaccount_id - ) - else: - spot_orders_ix = await drift_acct.get_place_spot_orders_ix( - order_params, subaccount_id - ) + place_orders_ix = drift_acct.get_place_orders_ix(order_params) # perp_orders_ix = [ await drift_acct.get_place_perp_order_ix(order_params[0], subaccount_id)] - await drift_acct.send_ixs(perp_orders_ix + spot_orders_ix) + await drift_acct.send_ixs([place_orders_ix]) if __name__ == "__main__": diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index e073a0fc..99c98e31 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -700,6 +700,59 @@ def get_place_perp_order_ix( return ix + async def place_orders( + self, + order_params: List[OrderParams], + sub_account_id: int = 0, + ): + return await self.send_ixs( + [ + self.get_place_orders_ix(order_params, sub_account_id), + ] + ) + + def get_place_orders_ix( + self, + order_params: List[OrderParams], + sub_account_id: int = 0, + ): + user_account_public_key = self.get_user_account_public_key(sub_account_id) + user_stats_public_key = self.get_user_stats_public_key() + + readable_perp_market_indexes = [] + readable_spot_market_indexes = [] + for order_param in order_params: + order_param.check_market_type() + + if "PERP" in str(order_param.market_type): + readable_perp_market_indexes.append(order_param.market_index) + else: + if len(readable_spot_market_indexes) == 0: + readable_spot_market_indexes.append(QUOTE_SPOT_MARKET_INDEX) + + readable_spot_market_indexes.append(order_param.market_index) + + remaining_accounts = self.get_remaining_accounts( + readable_perp_market_indexes=readable_perp_market_indexes, + readable_spot_market_indexes=readable_spot_market_indexes, + user_accounts=[self.get_user_account(sub_account_id)], + ) + + ix = self.program.instruction["place_orders"]( + order_params, + ctx=Context( + accounts={ + "state": self.get_state_public_key(), + "user": user_account_public_key, + "userStats": user_stats_public_key, + "authority": self.wallet.public_key, + }, + remaining_accounts=remaining_accounts, + ), + ) + + return ix + async def cancel_order( self, order_id: Optional[int] = None, @@ -786,98 +839,6 @@ def get_cancel_orders_ix( ), ) - def get_place_spot_orders_ix( - self, - order_params: List[OrderParams], - sub_account_id: int = 0, - ): - user_account_public_key = self.get_user_account_public_key(sub_account_id) - - remaining_accounts = self.get_remaining_accounts( - readable_spot_market_indexes=[ - QUOTE_SPOT_MARKET_INDEX, - order_params.market_index, - ], - user_accounts=[self.get_user_account(sub_account_id)], - ) - - ixs = [ - self.program.instruction["cancel_orders"]( - None, - None, - None, - ctx=Context( - accounts={ - "state": self.get_state_public_key(), - "user": self.get_user_account_public_key(sub_account_id), - "authority": self.wallet.public_key, - }, - remaining_accounts=remaining_accounts, - ), - ) - ] - for order_param in order_params: - ix = self.program.instruction["place_spot_order"]( - order_param, - ctx=Context( - accounts={ - "state": self.get_state_public_key(), - "user": user_account_public_key, - "authority": self.wallet.public_key, - }, - remaining_accounts=remaining_accounts, - ), - ) - ixs.append(ix) - - return ixs - - async def get_place_perp_orders_ix( - self, order_params: List[OrderParams], sub_account_id: int = 0, cancel_all=True - ): - [order_param.set_perp() for order_param in order_params] - - user_account_public_key = self.get_user_account_public_key(sub_account_id) - - readable_market_indexes = list(set([x.market_index for x in order_params])) - remaining_accounts = self.get_remaining_accounts( - readable_perp_market_indexes=readable_market_indexes, - user_accounts=[self.get_user_account(sub_account_id)], - ) - - ixs = [] - if cancel_all: - ixs.append( - self.program.instruction["cancel_orders"]( - None, - None, - None, - ctx=Context( - accounts={ - "state": self.get_state_public_key(), - "user": self.get_user_account_public_key(sub_account_id), - "authority": self.wallet.public_key, - }, - remaining_accounts=remaining_accounts, - ), - ) - ) - for order_param in order_params: - ix = self.program.instruction["place_perp_order"]( - order_param, - ctx=Context( - accounts={ - "state": self.get_state_public_key(), - "user": user_account_public_key, - "authority": self.wallet.public_key, - }, - remaining_accounts=remaining_accounts, - ), - ) - ixs.append(ix) - - return ixs - async def place_and_take_perp_order( self, order_params: OrderParams, diff --git a/tests/test.py b/tests/test.py index d38f8835..c5f9327c 100644 --- a/tests/test.py +++ b/tests/test.py @@ -36,6 +36,7 @@ PerpMarketAccount, OrderType, OrderParams, + MarketType, # SwapDirection, ) from driftpy.accounts import ( @@ -236,13 +237,14 @@ async def test_open_orders( assert open_orders == user_account.orders order_params = OrderParams( + market_type=MarketType.PERP(), order_type=OrderType.MARKET(), market_index=0, base_asset_amount=int(1 * BASE_PRECISION), direction=PositionDirection.LONG(), user_order_id=169, ) - ixs = await drift_client.get_place_perp_orders_ix([order_params]) + ixs = drift_client.get_place_orders_ix([order_params]) await drift_client.send_ixs(ixs) await drift_user.account_subscriber.update_cache() open_orders_after = drift_user.get_open_orders()