From 8257c77289d74fdd3c22ea4e9ab7bc0a166cf871 Mon Sep 17 00:00:00 2001 From: jordy25519 Date: Mon, 3 Jun 2024 12:22:55 +0800 Subject: [PATCH] add option to force include market accounts in TXBuilder --- src/lib.rs | 85 ++++++++++++++++++++++++++++++++++------------------ src/types.rs | 4 +++ 2 files changed, 60 insertions(+), 29 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 485aa2b..87b9c04 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1029,6 +1029,24 @@ impl DriftClientBackend { } } +/// Markets forced to include by `TransactionBuilder` +#[derive(Default)] +struct ForceMarkets { + readable: Vec, + writeable: Vec, +} + +impl ForceMarkets { + /// Add a market to the forced markets in r/w mode + pub fn add_markets(&mut self, markets: &[MarketId], write: bool) { + if write { + self.writeable = markets.to_vec(); + } else { + self.readable = markets.to_vec(); + } + } +} + /// Composable Tx builder for Drift program /// /// Prefer `DriftClient::init_tx` @@ -1067,6 +1085,8 @@ pub struct TransactionBuilder<'a> { legacy: bool, /// add additional lookup tables (v0 only) lookup_tables: Vec, + /// some markets forced to include in the tx accounts list + force_markets: ForceMarkets, } impl<'a> TransactionBuilder<'a> { @@ -1097,8 +1117,14 @@ impl<'a> TransactionBuilder<'a> { ixs: Default::default(), lookup_tables: vec![program_data.lookup_table.clone()], legacy: false, + force_markets: Default::default(), } } + /// force given `markets` to be included in the final tx accounts list (ensure to call before building ixs) + pub fn force_include_markets(&mut self, readable: &[MarketId], writeable: &[MarketId]) { + self.force_markets.add_markets(readable, false); + self.force_markets.add_markets(writeable, true); + } /// Use legacy tx mode pub fn legacy(mut self) -> Self { self.legacy = true; @@ -1206,10 +1232,11 @@ impl<'a> TransactionBuilder<'a> { /// Place new orders for account pub fn place_orders(mut self, orders: Vec) -> Self { - let readable_accounts: Vec = orders + let mut readable_accounts: Vec = orders .iter() .map(|o| (o.market_index, o.market_type).into()) .collect(); + readable_accounts.extend(&self.force_markets.readable); let accounts = build_accounts( self.program_data, @@ -1353,22 +1380,22 @@ impl<'a> TransactionBuilder<'a> { /// Modify existing order(s) by order id pub fn modify_orders(mut self, orders: &[(u32, ModifyOrderParams)]) -> Self { - for (order_id, params) in orders { - let accounts = build_accounts( - self.program_data, - drift::accounts::PlaceOrder { - state: *state_account(), - authority: self.authority, - user: self.sub_account, - }, - &[self.account_data.as_ref()], - &[], - &[], - ); + let accounts = build_accounts( + self.program_data, + drift::accounts::PlaceOrder { + state: *state_account(), + authority: self.authority, + user: self.sub_account, + }, + &[self.account_data.as_ref()], + self.force_markets.readable.as_slice(), + &[], + ); + for (order_id, params) in orders { let ix = Instruction { program_id: constants::PROGRAM_ID, - accounts, + accounts: accounts.clone(), data: InstructionData::data(&drift::instruction::ModifyOrder { order_id: Some(*order_id), modify_order_params: params.clone(), @@ -1382,22 +1409,22 @@ impl<'a> TransactionBuilder<'a> { /// Modify existing order(s) by user order id pub fn modify_orders_by_user_id(mut self, orders: &[(u8, ModifyOrderParams)]) -> Self { - for (user_order_id, params) in orders { - let accounts = build_accounts( - self.program_data, - drift::accounts::PlaceOrder { - state: *state_account(), - authority: self.authority, - user: self.sub_account, - }, - &[self.account_data.as_ref()], - &[], - &[], - ); + let accounts = build_accounts( + self.program_data, + drift::accounts::PlaceOrder { + state: *state_account(), + authority: self.authority, + user: self.sub_account, + }, + &[self.account_data.as_ref()], + self.force_markets.readable.as_slice(), + &[], + ); + for (user_order_id, params) in orders { let ix = Instruction { program_id: constants::PROGRAM_ID, - accounts, + accounts: accounts.clone(), data: InstructionData::data(&drift::instruction::ModifyOrderByUserId { user_order_id: *user_order_id, modify_order_params: params.clone(), @@ -1439,7 +1466,7 @@ impl<'a> TransactionBuilder<'a> { taker_stats: Wallet::derive_stats_account(taker, &constants::PROGRAM_ID), }, &[self.account_data.as_ref(), &taker_account], - &[], + &self.force_markets.readable, if is_perp { &perp_writable } else { @@ -1514,7 +1541,7 @@ impl<'a> TransactionBuilder<'a> { user_stats: Wallet::derive_stats_account(&self.authority, &constants::PROGRAM_ID), }, user_accounts.as_slice(), - &[], + &self.force_markets.readable, if is_perp { &perp_writable } else { diff --git a/src/types.rs b/src/types.rs index 5899a62..ede600d 100644 --- a/src/types.rs +++ b/src/types.rs @@ -85,6 +85,10 @@ impl MarketId { index: 0, kind: MarketType::Spot, }; + /// Convert self into its parts + pub fn is_perp(self) -> bool { + self.kind == MarketType::Perp + } } impl From<(u16, MarketType)> for MarketId {