diff --git a/crates/drift-ffi-sys b/crates/drift-ffi-sys index d167a0d..b7c993e 160000 --- a/crates/drift-ffi-sys +++ b/crates/drift-ffi-sys @@ -1 +1 @@ -Subproject commit d167a0dbc0d4c296c6d2fc44155c8b652292df16 +Subproject commit b7c993edbf2259abddf4dd7e8b12576e0d4efeb2 diff --git a/crates/src/ffi.rs b/crates/src/ffi.rs index fd8383c..f5d4978 100644 --- a/crates/src/ffi.rs +++ b/crates/src/ffi.rs @@ -160,25 +160,25 @@ pub fn calculate_margin_requirement_and_total_collateral_and_liability_info( impl types::SpotPosition { pub fn is_available(&self) -> bool { - unsafe { spot_position_is_available(&self) } + unsafe { spot_position_is_available(self) } } pub fn get_signed_token_amount(&self, market: &accounts::SpotMarket) -> SdkResult { - to_sdk_result(unsafe { spot_position_get_signed_token_amount(&self, market) }) + to_sdk_result(unsafe { spot_position_get_signed_token_amount(self, market) }) } pub fn get_token_amount(&self, market: &accounts::SpotMarket) -> SdkResult { - to_sdk_result(unsafe { spot_position_get_token_amount(&self, market) }) + to_sdk_result(unsafe { spot_position_get_token_amount(self, market) }) } } impl types::PerpPosition { pub fn get_unrealized_pnl(&self, oracle_price: i64) -> SdkResult { - to_sdk_result(unsafe { perp_position_get_unrealized_pnl(&self, oracle_price) }) + to_sdk_result(unsafe { perp_position_get_unrealized_pnl(self, oracle_price) }) } pub fn is_available(&self) -> bool { - unsafe { perp_position_is_available(&self) } + unsafe { perp_position_is_available(self) } } pub fn is_open_position(&self) -> bool { - unsafe { perp_position_is_open_position(&self) } + unsafe { perp_position_is_open_position(self) } } pub fn worst_case_base_asset_amount( &self, @@ -186,7 +186,7 @@ impl types::PerpPosition { contract_type: ContractType, ) -> SdkResult { to_sdk_result(unsafe { - perp_position_worst_case_base_asset_amount(&self, oracle_price, contract_type) + perp_position_worst_case_base_asset_amount(self, oracle_price, contract_type) }) } pub fn simulate_settled_lp_position( @@ -195,7 +195,7 @@ impl types::PerpPosition { oracle_price: i64, ) -> SdkResult { to_sdk_result(unsafe { - perp_position_simulate_settled_lp_position(&self, market, oracle_price) + perp_position_simulate_settled_lp_position(self, market, oracle_price) }) } } @@ -203,19 +203,19 @@ impl types::PerpPosition { impl accounts::User { pub fn get_spot_position(&self, market_index: u16) -> SdkResult { // TODO: no clone - to_sdk_result(unsafe { user_get_spot_position(&self, market_index) }).map(|p| *p) + to_sdk_result(unsafe { user_get_spot_position(self, market_index) }).copied() } pub fn get_perp_position(&self, market_index: u16) -> SdkResult { - to_sdk_result(unsafe { user_get_perp_position(&self, market_index) }).map(|p| *p) + to_sdk_result(unsafe { user_get_perp_position(self, market_index) }).copied() } } impl types::Order { pub fn is_limit_order(&self) -> bool { - unsafe { order_is_limit_order(&self) } + unsafe { order_is_limit_order(self) } } pub fn is_resting_limit_order(&self, slot: Slot) -> SdkResult { - to_sdk_result(unsafe { order_is_resting_limit_order(&self, slot) }) + to_sdk_result(unsafe { order_is_resting_limit_order(self, slot) }) } } @@ -227,7 +227,7 @@ impl accounts::SpotMarket { margin_requirement_type: MarginRequirementType, ) -> SdkResult { to_sdk_result(unsafe { - spot_market_get_asset_weight(&self, size, oracle_price, margin_requirement_type) + spot_market_get_asset_weight(self, size, oracle_price, margin_requirement_type) }) } pub fn get_liability_weight( @@ -236,7 +236,7 @@ impl accounts::SpotMarket { margin_requirement_type: MarginRequirementType, ) -> SdkResult { to_sdk_result(unsafe { - spot_market_get_liability_weight(&self, size, margin_requirement_type) + spot_market_get_liability_weight(self, size, margin_requirement_type) }) } } @@ -247,10 +247,10 @@ impl accounts::PerpMarket { size: u128, margin_requirement_type: MarginRequirementType, ) -> SdkResult { - to_sdk_result(unsafe { perp_market_get_margin_ratio(&self, size, margin_requirement_type) }) + to_sdk_result(unsafe { perp_market_get_margin_ratio(self, size, margin_requirement_type) }) } pub fn get_open_interest(&self) -> u128 { - unsafe { perp_market_get_open_interest(&self) } + unsafe { perp_market_get_open_interest(self) } } } diff --git a/crates/src/jit_client.rs b/crates/src/jit_client.rs index d7274b9..1f2204e 100644 --- a/crates/src/jit_client.rs +++ b/crates/src/jit_client.rs @@ -152,8 +152,8 @@ impl JitProxyClient { drift_program: constants::PROGRAM_ID, }, &[¶ms.taker, account_data], - &[], - writable_markets.as_slice(), + [].iter(), + writable_markets.iter(), ); if let Some(referrer_info) = params.referrer_info { diff --git a/crates/src/lib.rs b/crates/src/lib.rs index d4a698b..4575d60 100644 --- a/crates/src/lib.rs +++ b/crates/src/lib.rs @@ -598,7 +598,7 @@ impl DriftClientBackend { self.perp_market_map.subscribe(), self.spot_market_map.subscribe(), self.oracle_map.subscribe(), - self.account_map.subscribe_account(&state_account()), + self.account_map.subscribe_account(state_account()), )?; Ok(()) @@ -609,7 +609,7 @@ impl DriftClientBackend { self.blockhash_subscriber.unsubscribe(); self.perp_market_map.unsubscribe()?; self.spot_market_map.unsubscribe()?; - self.account_map.unsubscribe_account(&state_account()); + self.account_map.unsubscribe_account(state_account()); self.oracle_map.unsubscribe().await } @@ -838,12 +838,37 @@ impl DriftClientBackend { } } +/// Configure markets as forced for inclusion by `TransactionBuilder` +/// +/// In contrast, without this Transactions are built using the latest known state of +/// users's open positions and orders, which can result in race conditions when executed onchain. +#[derive(Default)] +struct ForceMarkets { + /// markets must include as readable + readable: Vec, + /// markets must include as writeable + writeable: Vec, +} + +impl ForceMarkets { + /// Set given `markets` as readable, enforcing there inclusion in a final Tx + pub fn with_readable(&mut self, markets: &[MarketId]) -> &mut Self { + self.readable = markets.to_vec(); + self + } + /// Set given `markets` as writeable, enforcing there inclusion in a final Tx + pub fn with_writeable(&mut self, markets: &[MarketId]) -> &mut Self { + self.writeable = markets.to_vec(); + self + } +} + /// Composable Tx builder for Drift program /// /// Prefer `DriftClient::init_tx` /// /// ```ignore -/// use drift_sdk::{types::Context, TransactionBuilder, Wallet}; +/// use drift_rs::{types::Context, TransactionBuilder, Wallet}; /// /// let wallet = Wallet::from_seed_bs58(Context::Dev, "seed"); /// let client = DriftClient::new("api.example.com").await.unwrap(); @@ -876,6 +901,8 @@ pub struct TransactionBuilder<'a> { legacy: bool, /// add additional lookup tables (v0 only) lookup_tables: Vec, + /// some markets forced to include in the tx accounts list + force_markets: ForceMarkets, } impl<'a> TransactionBuilder<'a> { @@ -906,8 +933,14 @@ impl<'a> TransactionBuilder<'a> { ixs: Default::default(), lookup_tables: vec![program_data.lookup_table.clone()], legacy: false, + force_markets: Default::default(), } } + /// force given `markets` to be included in the final tx accounts list (ensure to call before building ixs) + pub fn force_include_markets(&mut self, readable: &[MarketId], writeable: &[MarketId]) { + self.force_markets.with_readable(readable); + self.force_markets.with_writeable(writeable); + } /// Use legacy tx mode pub fn legacy(mut self) -> Self { self.legacy = true; @@ -955,8 +988,8 @@ impl<'a> TransactionBuilder<'a> { token_program: constants::TOKEN_PROGRAM_ID, }, &[self.account_data.as_ref()], - &[], - &[MarketId::spot(spot_market_index)], + self.force_markets.readable.iter(), + [MarketId::spot(spot_market_index)].iter(), ); let ix = Instruction { @@ -994,8 +1027,10 @@ impl<'a> TransactionBuilder<'a> { token_program: constants::TOKEN_PROGRAM_ID, }, &[self.account_data.as_ref()], - &[], - &[MarketId::spot(spot_market_index)], + self.force_markets.readable.iter(), + [MarketId::spot(spot_market_index)] + .iter() + .chain(self.force_markets.writeable.iter()), ); let ix = Instruction { @@ -1015,10 +1050,11 @@ impl<'a> TransactionBuilder<'a> { /// Place new orders for account pub fn place_orders(mut self, orders: Vec) -> Self { - let readable_accounts: Vec = orders + let mut readable_accounts: Vec = orders .iter() .map(|o| (o.market_index, o.market_type).into()) .collect(); + readable_accounts.extend(&self.force_markets.readable); let accounts = build_accounts( self.program_data, @@ -1028,8 +1064,8 @@ impl<'a> TransactionBuilder<'a> { user: self.sub_account, }, &[self.account_data.as_ref()], - readable_accounts.as_ref(), - &[], + readable_accounts.iter(), + self.force_markets.writeable.iter(), ); let ix = Instruction { @@ -1053,8 +1089,8 @@ impl<'a> TransactionBuilder<'a> { user: self.sub_account, }, &[self.account_data.as_ref()], - &[], - &[], + self.force_markets.readable.iter(), + self.force_markets.writeable.iter(), ); let ix = Instruction { @@ -1090,8 +1126,10 @@ impl<'a> TransactionBuilder<'a> { user: self.sub_account, }, &[self.account_data.as_ref()], - &[(idx, kind).into()], - &[], + [(idx, kind).into()] + .iter() + .chain(self.force_markets.readable.iter()), + self.force_markets.writeable.iter(), ); let ix = Instruction { @@ -1118,8 +1156,8 @@ impl<'a> TransactionBuilder<'a> { user: self.sub_account, }, &[self.account_data.as_ref()], - &[], - &[], + self.force_markets.readable.iter(), + self.force_markets.writeable.iter(), ); let ix = Instruction { @@ -1142,8 +1180,8 @@ impl<'a> TransactionBuilder<'a> { user: self.sub_account, }, &[self.account_data.as_ref()], - &[], - &[], + self.force_markets.readable.iter(), + self.force_markets.writeable.iter(), ); for user_order_id in user_order_ids { @@ -1162,22 +1200,22 @@ impl<'a> TransactionBuilder<'a> { /// Modify existing order(s) by order id pub fn modify_orders(mut self, orders: &[(u32, ModifyOrderParams)]) -> Self { - for (order_id, params) in orders { - let accounts = build_accounts( - self.program_data, - types::accounts::PlaceOrders { - state: *state_account(), - authority: self.authority, - user: self.sub_account, - }, - &[self.account_data.as_ref()], - &[], - &[], - ); + let accounts = build_accounts( + self.program_data, + types::accounts::ModifyOrder { + state: *state_account(), + authority: self.authority, + user: self.sub_account, + }, + &[self.account_data.as_ref()], + self.force_markets.readable.iter(), + self.force_markets.writeable.iter(), + ); + for (order_id, params) in orders { let ix = Instruction { program_id: constants::PROGRAM_ID, - accounts, + accounts: accounts.clone(), data: InstructionData::data(&drift_idl::instructions::ModifyOrder { order_id: Some(*order_id), modify_order_params: *params, @@ -1191,22 +1229,22 @@ impl<'a> TransactionBuilder<'a> { /// Modify existing order(s) by user order id pub fn modify_orders_by_user_id(mut self, orders: &[(u8, ModifyOrderParams)]) -> Self { - for (user_order_id, params) in orders { - let accounts = build_accounts( - self.program_data, - types::accounts::PlaceOrders { - state: *state_account(), - authority: self.authority, - user: self.sub_account, - }, - &[self.account_data.as_ref()], - &[], - &[], - ); + let accounts = build_accounts( + self.program_data, + types::accounts::PlaceOrders { + state: *state_account(), + authority: self.authority, + user: self.sub_account, + }, + &[self.account_data.as_ref()], + self.force_markets.readable.iter(), + self.force_markets.writeable.iter(), + ); + for (user_order_id, params) in orders { let ix = Instruction { program_id: constants::PROGRAM_ID, - accounts, + accounts: accounts.clone(), data: InstructionData::data(&drift_idl::instructions::ModifyOrderByUserId { user_order_id: *user_order_id, modify_order_params: *params, @@ -1248,12 +1286,13 @@ impl<'a> TransactionBuilder<'a> { taker_stats: Wallet::derive_stats_account(taker), }, &[self.account_data.as_ref(), taker_account], - &[], + self.force_markets.readable.iter(), if is_perp { - &perp_writable + perp_writable.iter() } else { - &spot_writable - }, + spot_writable.iter() + } + .chain(self.force_markets.writeable.iter()), ); if let Some(referrer) = referrer { @@ -1323,12 +1362,13 @@ impl<'a> TransactionBuilder<'a> { user_stats: Wallet::derive_stats_account(&self.authority), }, user_accounts.as_slice(), - &[], + self.force_markets.readable.iter(), if is_perp { - &perp_writable + perp_writable.iter() } else { - &spot_writable - }, + spot_writable.iter() + } + .chain(self.force_markets.writeable.iter()), ); if referrer.is_some_and(|r| !maker_info.is_some_and(|(m, _)| m == r)) { @@ -1403,12 +1443,12 @@ impl<'a> TransactionBuilder<'a> { /// /// # Panics /// if the user has positions in an unknown market (i.e unsupported by the SDK) -pub fn build_accounts( +pub fn build_accounts<'a>( program_data: &ProgramData, base_accounts: impl ToAccountMetas, users: &[&User], - markets_readable: &[MarketId], - markets_writable: &[MarketId], + markets_readable: impl Iterator, + markets_writable: impl Iterator, ) -> Vec { // the order of accounts returned must be instruction, oracles, spot, perps see (https://github.com/drift-labs/protocol-v2/blob/master/programs/drift/src/instructions/optional_accounts.rs#L28) let mut seen = [0_u64; 2]; // [spot, perp] diff --git a/crates/src/math/liquidation.rs b/crates/src/math/liquidation.rs index a3d0bb3..6eda4c7 100644 --- a/crates/src/math/liquidation.rs +++ b/crates/src/math/liquidation.rs @@ -47,7 +47,7 @@ pub fn calculate_liquidation_price_and_unrealized_pnl( .get_perp_position(market_index) .map_err(|_| SdkError::NoPosiiton(market_index))?; - let unrealized_pnl = calculate_unrealized_pnl_inner(&position.into(), oracle.data.price)?; + let unrealized_pnl = calculate_unrealized_pnl_inner(&position, oracle.data.price)?; // matching spot market e.g. sol-perp => SOL spot let mut builder = AccountsListBuilder::default(); @@ -82,7 +82,7 @@ pub fn calculate_unrealized_pnl( .get_oracle_price_data_and_slot_for_perp_market(market_index) .map(|x| x.data.price) .unwrap_or(0); - calculate_unrealized_pnl_inner(&position.into(), oracle_price) + calculate_unrealized_pnl_inner(&position, oracle_price) } else { Err(SdkError::NoPosiiton(market_index)) } diff --git a/crates/src/types.rs b/crates/src/types.rs index 0cc80ac..4e02d8f 100644 --- a/crates/src/types.rs +++ b/crates/src/types.rs @@ -129,9 +129,13 @@ impl MarketId { pub fn kind(&self) -> MarketType { self.kind } + /// Convert self into its parts pub fn to_parts(self) -> (u16, MarketType) { (self.index, self.kind) } + pub fn is_perp(self) -> bool { + self.kind == MarketType::Perp + } } impl From<(u16, MarketType)> for MarketId {