Skip to content

Commit

Permalink
add place_orders
Browse files Browse the repository at this point in the history
  • Loading branch information
crispheaney committed Nov 27, 2023
1 parent 8ead727 commit 34300b6
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 135 deletions.
58 changes: 16 additions & 42 deletions examples/limit_order_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,62 +202,36 @@ async def main(
len(ask_prices)
)

default_order_params = OrderParams(
order_type=OrderType.LIMIT(),
market_type=market_type,
direction=PositionDirection.LONG(),
user_order_id=0,
base_asset_amount=0,
price=0,
market_index=market_index,
reduce_only=False,
post_only=PostOnlyParams.TRY_POST_ONLY()
if not taker
else PostOnlyParams.NONE(),
immediate_or_cancel=False,
trigger_price=0,
trigger_condition=OrderTriggerCondition.ABOVE(),
oracle_price_offset=0,
auction_duration=None,
max_ts=None,
auction_start_price=None,
auction_end_price=None,
)
order_params = []
for x in bid_prices:
bid_order_params = copy.deepcopy(default_order_params)
bid_order_params.direction = PositionDirection.LONG()
bid_order_params.base_asset_amount = int(
base_asset_amount_per_bid * BASE_PRECISION
bid_order_params = OrderParams(
order_type=OrderType.LIMIT(),
market_index=market_index,
market_type=market_type,
direction=PositionDirection.LONG(),
base_asset_amount=int(base_asset_amount_per_bid * BASE_PRECISION),
price=int(x * PRICE_PRECISION),
)
bid_order_params.price = int(x * PRICE_PRECISION)
if bid_order_params.base_asset_amount > 0:
order_params.append(bid_order_params)

for x in ask_prices:
ask_order_params = copy.deepcopy(default_order_params)
ask_order_params.base_asset_amount = int(
base_asset_amount_per_ask * BASE_PRECISION
ask_order_params = OrderParams(
order_type=OrderType.LIMIT(),
market_index=market_index,
market_type=market_type,
direction=PositionDirection.SHORT(),
base_asset_amount=int(base_asset_amount_per_ask * BASE_PRECISION),
price=int(x * PRICE_PRECISION),
)
ask_order_params.direction = PositionDirection.SHORT()
ask_order_params.price = int(x * PRICE_PRECISION)
if ask_order_params.base_asset_amount > 0:
order_params.append(ask_order_params)
# print(order_params)
# order_print([bid_order_params, ask_order_params], market_name)

perp_orders_ix = []
spot_orders_ix = []
if is_perp:
perp_orders_ix = await drift_acct.get_place_perp_orders_ix(
order_params, subaccount_id
)
else:
spot_orders_ix = await drift_acct.get_place_spot_orders_ix(
order_params, subaccount_id
)
place_orders_ix = drift_acct.get_place_orders_ix(order_params)
# perp_orders_ix = [ await drift_acct.get_place_perp_order_ix(order_params[0], subaccount_id)]
await drift_acct.send_ixs(perp_orders_ix + spot_orders_ix)
await drift_acct.send_ixs([place_orders_ix])


if __name__ == "__main__":
Expand Down
145 changes: 53 additions & 92 deletions src/driftpy/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,59 @@ def get_place_perp_order_ix(

return ix

async def place_orders(
self,
order_params: List[OrderParams],
sub_account_id: int = 0,
):
return await self.send_ixs(
[
self.get_place_orders_ix(order_params, sub_account_id),
]
)

def get_place_orders_ix(
self,
order_params: List[OrderParams],
sub_account_id: int = 0,
):
user_account_public_key = self.get_user_account_public_key(sub_account_id)
user_stats_public_key = self.get_user_stats_public_key()

readable_perp_market_indexes = []
readable_spot_market_indexes = []
for order_param in order_params:
order_param.check_market_type()

if "PERP" in str(order_param.market_type):
readable_perp_market_indexes.append(order_param.market_index)
else:
if len(readable_spot_market_indexes) == 0:
readable_spot_market_indexes.append(QUOTE_SPOT_MARKET_INDEX)

readable_spot_market_indexes.append(order_param.market_index)

remaining_accounts = self.get_remaining_accounts(
readable_perp_market_indexes=readable_perp_market_indexes,
readable_spot_market_indexes=readable_spot_market_indexes,
user_accounts=[self.get_user_account(sub_account_id)],
)

ix = self.program.instruction["place_orders"](
order_params,
ctx=Context(
accounts={
"state": self.get_state_public_key(),
"user": user_account_public_key,
"userStats": user_stats_public_key,
"authority": self.wallet.public_key,
},
remaining_accounts=remaining_accounts,
),
)

return ix

async def cancel_order(
self,
order_id: Optional[int] = None,
Expand Down Expand Up @@ -786,98 +839,6 @@ def get_cancel_orders_ix(
),
)

def get_place_spot_orders_ix(
self,
order_params: List[OrderParams],
sub_account_id: int = 0,
):
user_account_public_key = self.get_user_account_public_key(sub_account_id)

remaining_accounts = self.get_remaining_accounts(
readable_spot_market_indexes=[
QUOTE_SPOT_MARKET_INDEX,
order_params.market_index,
],
user_accounts=[self.get_user_account(sub_account_id)],
)

ixs = [
self.program.instruction["cancel_orders"](
None,
None,
None,
ctx=Context(
accounts={
"state": self.get_state_public_key(),
"user": self.get_user_account_public_key(sub_account_id),
"authority": self.wallet.public_key,
},
remaining_accounts=remaining_accounts,
),
)
]
for order_param in order_params:
ix = self.program.instruction["place_spot_order"](
order_param,
ctx=Context(
accounts={
"state": self.get_state_public_key(),
"user": user_account_public_key,
"authority": self.wallet.public_key,
},
remaining_accounts=remaining_accounts,
),
)
ixs.append(ix)

return ixs

async def get_place_perp_orders_ix(
self, order_params: List[OrderParams], sub_account_id: int = 0, cancel_all=True
):
[order_param.set_perp() for order_param in order_params]

user_account_public_key = self.get_user_account_public_key(sub_account_id)

readable_market_indexes = list(set([x.market_index for x in order_params]))
remaining_accounts = self.get_remaining_accounts(
readable_perp_market_indexes=readable_market_indexes,
user_accounts=[self.get_user_account(sub_account_id)],
)

ixs = []
if cancel_all:
ixs.append(
self.program.instruction["cancel_orders"](
None,
None,
None,
ctx=Context(
accounts={
"state": self.get_state_public_key(),
"user": self.get_user_account_public_key(sub_account_id),
"authority": self.wallet.public_key,
},
remaining_accounts=remaining_accounts,
),
)
)
for order_param in order_params:
ix = self.program.instruction["place_perp_order"](
order_param,
ctx=Context(
accounts={
"state": self.get_state_public_key(),
"user": user_account_public_key,
"authority": self.wallet.public_key,
},
remaining_accounts=remaining_accounts,
),
)
ixs.append(ix)

return ixs

async def place_and_take_perp_order(
self,
order_params: OrderParams,
Expand Down
4 changes: 3 additions & 1 deletion tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
PerpMarketAccount,
OrderType,
OrderParams,
MarketType,
# SwapDirection,
)
from driftpy.accounts import (
Expand Down Expand Up @@ -236,13 +237,14 @@ async def test_open_orders(
assert open_orders == user_account.orders

order_params = OrderParams(
market_type=MarketType.PERP(),
order_type=OrderType.MARKET(),
market_index=0,
base_asset_amount=int(1 * BASE_PRECISION),
direction=PositionDirection.LONG(),
user_order_id=169,
)
ixs = await drift_client.get_place_perp_orders_ix([order_params])
ixs = drift_client.get_place_orders_ix([order_params])
await drift_client.send_ixs(ixs)
await drift_user.account_subscriber.update_cache()
open_orders_after = drift_user.get_open_orders()
Expand Down

0 comments on commit 34300b6

Please sign in to comment.