diff --git a/Cargo.lock b/Cargo.lock index 6cf99ba..ddb4caa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1120,6 +1120,7 @@ dependencies = [ "drift", "env_logger 0.10.2", "fnv", + "futures", "futures-util", "hex", "hex-literal", diff --git a/Cargo.toml b/Cargo.toml index 1255a1e..3f83585 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,3 +57,4 @@ spl-associated-token-account = "1.1.2" anchor-client = "0.27.0" anchor-lang = "*" bytes = "*" +futures = "0.3.30" diff --git a/README.md b/README.md index 3995619..6b51c7b 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@
-

Drift Protocol v2 (Rust SDK)

+

drift-rs

Crates.io @@ -11,9 +11,9 @@

-# Drift Protocol v2 (Rust SDK) +# drift-rs -Rust SDK for building off chain clients for interacting with the [Drift V2](https://github.com/drift-labs/protocol-v2) protocol. +Experimental, high performance Rust SDK for building off chain clients for interacting with the [Drift V2](https://github.com/drift-labs/protocol-v2) protocol. ## Setup diff --git a/src/dlob/dlob.rs b/src/dlob/dlob.rs index eda1154..ad55014 100644 --- a/src/dlob/dlob.rs +++ b/src/dlob/dlob.rs @@ -14,7 +14,7 @@ use crate::dlob::dlob_node::{ use crate::dlob::market::{get_order_lists, Exchange, Market, OpenOrders, SubType}; use crate::event_emitter::Event; use crate::math::order::is_resting_limit_order; -use crate::usermap::Usermap; +use crate::usermap::UserMap; use crate::utils::market_type_to_string; #[derive(Clone)] @@ -43,7 +43,7 @@ impl DLOB { } } - pub fn build_from_usermap(&mut self, usermap: &Usermap, slot: u64) { + pub fn build_from_usermap(&mut self, usermap: &UserMap, slot: u64) { usermap.usermap.iter().par_bridge().for_each(|user_ref| { let user = user_ref.value(); let user_key = user_ref.key(); @@ -98,14 +98,14 @@ impl DLOB { .get_mut(&market_index) .expect(format!("Market index {} not found", market_index).as_str()); - let (order_list, subtype, node_type) = market.get_info_for_order_insert(&order, slot); + let (order_list, subtype, node_type) = market.get_info_for_order_insert(order, slot); - let node = create_node(node_type, order.clone(), user_account); + let node = create_node(node_type, *order, user_account); if let Some(order_list) = order_list { match subtype { - SubType::Bid => order_list.insert_bid(node.clone()), - SubType::Ask => order_list.insert_ask(node.clone()), + SubType::Bid => order_list.insert_bid(node), + SubType::Ask => order_list.insert_ask(node), _ => {} } } else { @@ -117,7 +117,7 @@ impl DLOB { let order_signature = get_order_signature(order_id, user_account); for order_list in get_order_lists(&self.exchange) { if let Some(node) = order_list.get_node(&order_signature) { - return Some(node.get_order().clone()); + return Some(*node.get_order()); } } diff --git a/src/dlob/dlob_builder.rs b/src/dlob/dlob_builder.rs index 9143c0a..3104acf 100644 --- a/src/dlob/dlob_builder.rs +++ b/src/dlob/dlob_builder.rs @@ -1,14 +1,11 @@ use crate::{ - dlob::dlob::DLOB, - event_emitter::{Event, EventEmitter}, - slot_subscriber::SlotSubscriber, - usermap::Usermap, - SdkResult, + dlob::dlob::DLOB, event_emitter::EventEmitter, slot_subscriber::SlotSubscriber, + usermap::UserMap, SdkResult, }; pub struct DLOBBuilder { slot_subscriber: SlotSubscriber, - usermap: Usermap, + usermap: UserMap, rebuild_frequency: u64, dlob: DLOB, event_emitter: EventEmitter, @@ -17,7 +14,7 @@ pub struct DLOBBuilder { impl DLOBBuilder { pub fn new( slot_subscriber: SlotSubscriber, - usermap: Usermap, + usermap: UserMap, rebuild_frequency: u64, ) -> DLOBBuilder { DLOBBuilder { @@ -56,6 +53,7 @@ impl DLOBBuilder { #[cfg(test)] mod tests { use super::*; + use crate::memcmp::get_user_with_order_filter; use crate::utils::get_ws_url; use solana_sdk::commitment_config::CommitmentConfig; use solana_sdk::commitment_config::CommitmentLevel; @@ -70,7 +68,12 @@ mod tests { }; let slot_subscriber = SlotSubscriber::new(get_ws_url(&endpoint.clone()).unwrap()); - let usermap = Usermap::new(commitment, endpoint, true); + let mut usermap = UserMap::new( + commitment, + endpoint, + true, + Some(vec![get_user_with_order_filter()]), + ); let mut dlob_builder = DLOBBuilder::new(slot_subscriber, usermap, 30); dlob_builder @@ -90,13 +93,18 @@ mod tests { #[tokio::test] #[cfg(rpc_tests)] async fn test_build_time() { - let endpoint = "url".to_string(); + let endpoint = "rpc".to_string(); let commitment = CommitmentConfig { commitment: CommitmentLevel::Processed, }; let mut slot_subscriber = SlotSubscriber::new(get_ws_url(&endpoint.clone()).unwrap()); - let mut usermap = Usermap::new(commitment, endpoint, true); + let mut usermap = UserMap::new( + commitment, + endpoint, + true, + Some(vec![get_user_with_order_filter()]), + ); let _ = slot_subscriber.subscribe().await; let _ = usermap.subscribe().await; diff --git a/src/dlob/dlob_node.rs b/src/dlob/dlob_node.rs index a3219d8..8a2972e 100644 --- a/src/dlob/dlob_node.rs +++ b/src/dlob/dlob_node.rs @@ -298,7 +298,7 @@ pub(crate) fn create_node(node_type: NodeType, order: Order, user_account: Pubke } pub(crate) fn get_order_signature(order_id: u32, user_account: Pubkey) -> String { - format!("{}-{}", order_id, user_account.to_string()) + format!("{}-{}", order_id, user_account) } #[cfg(test)] diff --git a/src/dlob/market.rs b/src/dlob/market.rs index ce1b759..0d5c409 100644 --- a/src/dlob/market.rs +++ b/src/dlob/market.rs @@ -61,12 +61,10 @@ impl Market { NodeType::Market } else if order.oracle_price_offset != 0 { NodeType::FloatingLimit + } else if is_resting_limit_order(order, slot) { + NodeType::RestingLimit } else { - if is_resting_limit_order(order, slot) { - NodeType::RestingLimit - } else { - NodeType::TakingLimit - } + NodeType::TakingLimit }; let order_list = match node_type { diff --git a/src/dlob/order_list.rs b/src/dlob/order_list.rs index 921af38..5a5c48e 100644 --- a/src/dlob/order_list.rs +++ b/src/dlob/order_list.rs @@ -26,20 +26,20 @@ impl Orderlist { pub fn insert_bid(&mut self, node: Node) { let order_sig = get_order_signature(node.get_order().order_id, node.get_user_account()); - self.order_sigs.insert(order_sig.clone(), node.clone()); + self.order_sigs.insert(order_sig.clone(), node); let directional = DirectionalNode::new(node, self.bid_sort_direction); self.bids.push(directional); } pub fn insert_ask(&mut self, node: Node) { let order_sig = get_order_signature(node.get_order().order_id, node.get_user_account()); - self.order_sigs.insert(order_sig.clone(), node.clone()); + self.order_sigs.insert(order_sig.clone(), node); let directional = DirectionalNode::new(node, self.ask_sort_direction); self.asks.push(directional); } pub fn get_best_bid(&mut self) -> Option { - if let Some(node) = self.bids.pop().map(|node| node.node.clone()) { + if let Some(node) = self.bids.pop().map(|node| node.node) { let order_sig = get_order_signature(node.get_order().order_id, node.get_user_account()); if self.order_sigs.contains_key(&order_sig) { self.order_sigs.remove(&order_sig); @@ -50,7 +50,7 @@ impl Orderlist { } pub fn get_best_ask(&mut self) -> Option { - if let Some(node) = self.asks.pop().map(|node| node.node.clone()) { + if let Some(node) = self.asks.pop().map(|node| node.node) { let order_sig = get_order_signature(node.get_order().order_id, node.get_user_account()); if self.order_sigs.contains_key(&order_sig) { self.order_sigs.remove(&order_sig); @@ -61,7 +61,7 @@ impl Orderlist { } pub fn get_node(&self, order_sig: &String) -> Option { - self.order_sigs.get(order_sig).map(|node| node.clone()) + self.order_sigs.get(order_sig).map(|node| *node) } pub fn bids_empty(&self) -> bool { diff --git a/src/lib.rs b/src/lib.rs index f4f22ea..d8e344e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ //! Drift SDK -use std::{borrow::Cow, sync::Arc, time::Duration}; +use std::{borrow::Cow, rc::Rc, sync::Arc, time::Duration}; use anchor_lang::{AccountDeserialize, Discriminator, InstructionData, ToAccountMetas}; use async_utils::{retry_policy, spawn_retry_task}; @@ -20,6 +20,8 @@ use drift::{ use fnv::FnvHashMap; use futures_util::{future::BoxFuture, FutureExt, StreamExt}; use log::{debug, warn}; +use marketmap::MarketMap; +use oraclemap::{OracleMap, OraclePriceDataAndSlot}; use solana_account_decoder::UiAccountEncoding; use solana_client::{ nonblocking::{pubsub_client::PubsubClient, rpc_client::RpcClient}, @@ -54,6 +56,7 @@ use crate::constants::{ // utils pub mod async_utils; +pub mod math; pub mod memcmp; pub mod utils; @@ -63,19 +66,21 @@ pub mod types; // internal infra pub mod event_emitter; +pub mod websocket_account_subscriber; pub mod websocket_program_account_subscriber; // subscribers pub mod auction_subscriber; pub mod dlob_client; pub mod event_subscriber; +pub mod marketmap; +pub mod oraclemap; #[cfg(feature = "jit")] pub mod jit_client; pub mod slot_subscriber; +pub mod usermap; pub mod dlob; -pub mod math; -pub mod usermap; use types::*; @@ -322,6 +327,18 @@ impl DriftClient { }) } + /// Subscribe to the Drift Client Backend + /// This is a no-op if already subscribed + pub async fn subscribe(&self) -> SdkResult<()> { + self.backend.subscribe().await + } + + /// Unsubscribe from the Drift Client Backend + /// This is a no-op if not subscribed + pub async fn unsubscribe(&self) -> SdkResult<()> { + self.backend.unsubscribe().await + } + /// Return a handle to the inner RPC client pub fn inner(&self) -> &RpcClient { self.backend.client() @@ -586,6 +603,63 @@ impl DriftClient { .get_recent_priority_fees(writable_markets, window) .await } + + pub fn get_perp_market_account_and_slot( + &self, + market_index: u16, + ) -> Option> { + self.backend.get_perp_market_account_and_slot(market_index) + } + + pub fn get_spot_market_account_and_slot( + &self, + market_index: u16, + ) -> Option> { + self.backend.get_spot_market_account_and_slot(market_index) + } + + pub fn get_perp_market_account(&self, market_index: u16) -> Option { + self.backend + .get_perp_market_account_and_slot(market_index) + .map(|x| x.data) + } + + pub fn get_spot_market_account(&self, market_index: u16) -> Option { + self.backend + .get_spot_market_account_and_slot(market_index) + .map(|x| x.data) + } + + pub fn num_perp_markets(&self) -> usize { + self.backend.num_perp_markets() + } + + pub fn num_spot_markets(&self) -> usize { + self.backend.num_spot_markets() + } + + pub fn get_oracle_price_data_and_slot( + &self, + oracle_pubkey: Pubkey, + ) -> Option { + self.backend.get_oracle_price_data_and_slot(oracle_pubkey) + } + + pub fn get_oracle_price_data_and_slot_for_perp_market( + &self, + market_index: u16, + ) -> Option { + self.backend + .get_oracle_price_data_and_slot_for_perp_market(market_index) + } + + pub fn get_oracle_price_data_and_slot_for_spot_market( + &self, + market_index: u16, + ) -> Option { + self.backend + .get_oracle_price_data_and_slot_for_spot_market(market_index) + } } /// Provides the heavy-lifting and network facing features of the SDK @@ -594,6 +668,9 @@ pub struct DriftClientBackend { rpc_client: RpcClient, account_provider: T, program_data: ProgramData, + perp_market_map: MarketMap, + spot_market_map: MarketMap, + oracle_map: Rc, } impl DriftClientBackend { @@ -604,10 +681,37 @@ impl DriftClientBackend { account_provider.commitment_config(), ); + let perp_market_map = MarketMap::::new( + account_provider.commitment_config(), + account_provider.endpoint(), + true, + ); + let spot_market_map = MarketMap::::new( + account_provider.commitment_config(), + account_provider.endpoint(), + true, + ); + + tokio::try_join!(perp_market_map.sync(), spot_market_map.sync())?; + + let perp_oracles = perp_market_map.oracles(); + let spot_oracles = spot_market_map.oracles(); + + let oracle_map = OracleMap::new( + account_provider.commitment_config(), + account_provider.endpoint(), + true, + perp_oracles, + spot_oracles, + ); + let mut this = Self { rpc_client, account_provider, program_data: ProgramData::uninitialized(), + perp_market_map, + spot_market_map, + oracle_map: Rc::new(oracle_map), }; let lookup_table_address = market_lookup_table(context); @@ -626,6 +730,101 @@ impl DriftClientBackend { Ok(this) } + async fn subscribe(&self) -> SdkResult<()> { + tokio::try_join!( + self.perp_market_map.subscribe(), + self.spot_market_map.subscribe(), + self.oracle_map.subscribe() + )?; + Ok(()) + } + + async fn unsubscribe(&self) -> SdkResult<()> { + tokio::try_join!( + self.perp_market_map.unsubscribe(), + self.spot_market_map.unsubscribe(), + self.oracle_map.unsubscribe() + )?; + Ok(()) + } + + fn get_perp_market_account_and_slot( + &self, + market_index: u16, + ) -> Option> { + self.perp_market_map.get(&market_index) + } + + fn get_spot_market_account_and_slot( + &self, + market_index: u16, + ) -> Option> { + self.spot_market_map.get(&market_index) + } + + fn num_perp_markets(&self) -> usize { + self.perp_market_map.size() + } + + fn num_spot_markets(&self) -> usize { + self.spot_market_map.size() + } + + fn get_oracle_price_data_and_slot( + &self, + oracle_pubkey: Pubkey, + ) -> Option { + self.oracle_map.get(&oracle_pubkey.to_string()) + } + + fn get_oracle_price_data_and_slot_for_perp_market( + &self, + market_index: u16, + ) -> Option { + let market = self.get_perp_market_account_and_slot(market_index)?; + + let oracle = market.data.amm.oracle; + let current_oracle = self + .oracle_map + .current_perp_oracle(market_index) + .expect("oracle"); + + if oracle != current_oracle { + let source = market.data.amm.oracle_source; + let clone = self.oracle_map.clone(); + tokio::task::spawn_local(async move { + let _ = clone.add_oracle(oracle, source).await; + clone.update_perp_oracle(market_index, oracle) + }); + } + + self.get_oracle_price_data_and_slot(current_oracle) + } + + fn get_oracle_price_data_and_slot_for_spot_market( + &self, + market_index: u16, + ) -> Option { + let market = self.get_spot_market_account_and_slot(market_index)?; + + let oracle = market.data.oracle; + let current_oracle = self + .oracle_map + .current_spot_oracle(market_index) + .expect("oracle"); + + if oracle != current_oracle { + let source = market.data.oracle_source; + let clone = self.oracle_map.clone(); + tokio::task::spawn_local(async move { + let _ = clone.add_oracle(oracle, source).await; + clone.update_spot_oracle(market_index, oracle); + }); + } + + self.get_oracle_price_data_and_slot(market.data.oracle) + } + /// Return a handle to the inner RPC client fn client(&self) -> &RpcClient { &self.rpc_client @@ -1612,6 +1811,17 @@ mod tests { account_provider_mocks: Mocks, keypair: Keypair, ) -> DriftClient { + let perp_market_map = MarketMap::::new( + CommitmentConfig::processed(), + DEVNET_ENDPOINT.to_string(), + false, + ); + let spot_market_map = MarketMap::::new( + CommitmentConfig::processed(), + DEVNET_ENDPOINT.to_string(), + false, + ); + let backend = DriftClientBackend { rpc_client: RpcClient::new_mock_with_mocks(DEVNET_ENDPOINT.to_string(), rpc_mocks), account_provider: RpcAccountProvider { @@ -1621,6 +1831,15 @@ mod tests { )), }, program_data: ProgramData::uninitialized(), + perp_market_map, + spot_market_map, + oracle_map: Rc::new(OracleMap::new( + CommitmentConfig::processed(), + DEVNET_ENDPOINT.to_string(), + true, + vec![], + vec![], + )), }; DriftClient { @@ -1631,6 +1850,38 @@ mod tests { } } + #[tokio::test] + #[cfg(rpc_tests)] + async fn test_marketmap_subscribe() { + let endpoint = "rpc"; + + let client = DriftClient::new( + Context::MainNet, + RpcAccountProvider::new(endpoint), + Keypair::new().into(), + ) + .await + .unwrap(); + + let _ = client.subscribe().await; + + tokio::time::sleep(tokio::time::Duration::from_secs(10)).await; + + for _ in 0..20 { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + let perp_market = client.get_perp_market_account_and_slot(0); + let slot = perp_market.unwrap().slot; + dbg!(slot); + } + + for _ in 0..20 { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + let spot_market = client.get_spot_market_account_and_slot(0); + let slot = spot_market.unwrap().slot; + dbg!(slot); + } + } + #[tokio::test] async fn get_market_accounts() { let client = DriftClient::new( diff --git a/src/marketmap.rs b/src/marketmap.rs new file mode 100644 index 0000000..9fc9034 --- /dev/null +++ b/src/marketmap.rs @@ -0,0 +1,282 @@ +use std::cell::{Cell, RefCell}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; + +use crate::event_emitter::EventEmitter; +use crate::memcmp::get_market_filter; +use crate::utils::{decode, get_ws_url}; +use crate::websocket_program_account_subscriber::{ + ProgramAccountUpdate, WebsocketProgramAccountOptions, WebsocketProgramAccountSubscriber, +}; +use crate::{DataAndSlot, SdkResult}; +use anchor_lang::AccountDeserialize; +use dashmap::DashMap; +use drift::state::oracle::OracleSource; +use drift::state::perp_market::PerpMarket; +use drift::state::spot_market::SpotMarket; +use drift::state::user::MarketType; +use serde_json::json; +use solana_account_decoder::UiAccountEncoding; +use solana_client::nonblocking::rpc_client::RpcClient; +use solana_client::rpc_config::{RpcAccountInfoConfig, RpcProgramAccountsConfig}; +use solana_client::rpc_request::RpcRequest; +use solana_client::rpc_response::{OptionalContext, RpcKeyedAccount}; +use solana_sdk::commitment_config::CommitmentConfig; +use solana_sdk::pubkey::Pubkey; + +pub trait Market { + const MARKET_TYPE: MarketType; + fn market_index(&self) -> u16; + fn oracle_info(&self) -> (u16, Pubkey, OracleSource); +} + +impl Market for PerpMarket { + const MARKET_TYPE: MarketType = MarketType::Perp; + + fn market_index(&self) -> u16 { + self.market_index + } + + fn oracle_info(&self) -> (u16, Pubkey, OracleSource) { + (self.market_index(), self.amm.oracle, self.amm.oracle_source) + } +} + +impl Market for SpotMarket { + const MARKET_TYPE: MarketType = MarketType::Spot; + + fn market_index(&self) -> u16 { + self.market_index + } + + fn oracle_info(&self) -> (u16, Pubkey, OracleSource) { + (self.market_index(), self.oracle, self.oracle_source) + } +} + +pub struct MarketMap { + subscribed: Cell, + subscription: RefCell, + marketmap: Arc>>, + sync_lock: Option>, + latest_slot: Arc, + commitment: CommitmentConfig, + rpc: RpcClient, + synced: bool, +} + +impl MarketMap { + pub fn new(commitment: CommitmentConfig, endpoint: String, sync: bool) -> Self { + let filters = vec![get_market_filter(T::MARKET_TYPE)]; + let options = WebsocketProgramAccountOptions { + filters, + commitment, + encoding: UiAccountEncoding::Base64, + }; + let event_emitter = EventEmitter::new(); + + let url = get_ws_url(&endpoint.clone()).unwrap(); + + let subscription = + WebsocketProgramAccountSubscriber::new("marketmap", url, options, event_emitter); + + let marketmap = Arc::new(DashMap::new()); + + let rpc = RpcClient::new_with_commitment(endpoint.clone(), commitment); + + let sync_lock = if sync { Some(Mutex::new(())) } else { None }; + + Self { + subscribed: Cell::new(false), + subscription: RefCell::new(subscription), + marketmap, + sync_lock, + latest_slot: Arc::new(AtomicU64::new(0)), + commitment, + rpc, + synced: false, + } + } + + pub async fn subscribe(&self) -> SdkResult<()> { + if self.sync_lock.is_some() { + self.sync().await?; + } + + if !self.subscribed.get() { + self.subscription.try_borrow_mut()?.subscribe::().await?; + self.subscribed.set(true); + + let marketmap = self.marketmap.clone(); + let latest_slot = self.latest_slot.clone(); + + self.subscription + .try_borrow()? + .event_emitter + .subscribe("marketmap", move |event| { + if let Some(update) = event.as_any().downcast_ref::>() { + let market_data_and_slot = update.data_and_slot.clone(); + if update.data_and_slot.slot > latest_slot.load(Ordering::Relaxed) { + latest_slot.store(update.data_and_slot.slot, Ordering::Relaxed); + } + marketmap.insert( + update.data_and_slot.clone().data.market_index(), + DataAndSlot { + data: market_data_and_slot.data, + slot: update.data_and_slot.slot, + }, + ); + } + }); + } + Ok(()) + } + + pub async fn unsubscribe(&self) -> SdkResult<()> { + if self.subscribed.get() { + self.subscription.try_borrow_mut()?.unsubscribe().await?; + self.subscribed.set(false); + self.marketmap.clear(); + self.latest_slot.store(0, Ordering::Relaxed); + } + Ok(()) + } + + pub fn values(&self) -> Vec { + self.marketmap.iter().map(|x| x.data.clone()).collect() + } + + pub fn oracles(&self) -> Vec<(u16, Pubkey, OracleSource)> { + self.values().iter().map(|x| x.oracle_info()).collect() + } + + pub fn size(&self) -> usize { + self.marketmap.len() + } + + pub fn contains(&self, market_index: &u16) -> bool { + self.marketmap.contains_key(market_index) + } + + pub fn get(&self, market_index: &u16) -> Option> { + self.marketmap + .get(market_index) + .map(|market| market.clone()) + } + + pub(crate) async fn sync(&self) -> SdkResult<()> { + if self.synced { + return Ok(()); + } + + let sync_lock = self.sync_lock.as_ref().expect("expected sync lock"); + + let lock = match sync_lock.try_lock() { + Ok(lock) => lock, + Err(_) => return Ok(()), + }; + + let options = self.subscription.try_borrow()?.options.clone(); + + let account_config = RpcAccountInfoConfig { + commitment: Some(self.commitment), + encoding: Some(options.encoding), + ..RpcAccountInfoConfig::default() + }; + + let gpa_config = RpcProgramAccountsConfig { + filters: Some(options.filters), + account_config, + with_context: Some(true), + }; + + let response = self + .rpc + .send::>>( + RpcRequest::GetProgramAccounts, + json!([drift::id().to_string(), gpa_config]), + ) + .await?; + + if let OptionalContext::Context(accounts) = response { + for account in accounts.value { + let slot = accounts.context.slot; + let market_data = account.account.data; + let data = decode::(market_data)?; + self.marketmap + .insert(data.market_index(), DataAndSlot { data, slot }); + } + + self.latest_slot + .store(accounts.context.slot, Ordering::Relaxed); + } + + drop(lock); + Ok(()) + } + + pub fn get_latest_slot(&self) -> u64 { + self.latest_slot.load(Ordering::Relaxed) + } +} + +#[cfg(test)] +mod tests { + use crate::marketmap::MarketMap; + use drift::state::perp_market::PerpMarket; + use drift::state::spot_market::SpotMarket; + use solana_sdk::commitment_config::CommitmentConfig; + use solana_sdk::commitment_config::CommitmentLevel; + + #[tokio::test] + #[cfg(rpc_tests)] + async fn test_marketmap_perp() { + let endpoint = "rpc".to_string(); + let commitment = CommitmentConfig { + commitment: CommitmentLevel::Processed, + }; + + let marketmap = MarketMap::::new(commitment, endpoint, true); + marketmap.subscribe().await.unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_secs(30)).await; + + dbg!(marketmap.size()); + assert!(marketmap.size() == 28); + + dbg!(marketmap.get_latest_slot()); + + marketmap.unsubscribe().await.unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_secs(10)).await; + + assert_eq!(marketmap.size(), 0); + assert_eq!(marketmap.subscribed.get(), false); + } + + #[tokio::test] + #[cfg(rpc_tests)] + async fn test_marketmap_spot() { + let endpoint = "rpc".to_string(); + let commitment = CommitmentConfig { + commitment: CommitmentLevel::Processed, + }; + + let marketmap = MarketMap::::new(commitment, endpoint, true); + marketmap.subscribe().await.unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_secs(30)).await; + + dbg!(marketmap.size()); + assert!(marketmap.size() == 13); + + dbg!(marketmap.get_latest_slot()); + + marketmap.unsubscribe().await.unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_secs(10)).await; + + assert_eq!(marketmap.size(), 0); + assert_eq!(marketmap.subscribed.get(), false); + } +} diff --git a/src/memcmp.rs b/src/memcmp.rs index 958f670..a89ed11 100644 --- a/src/memcmp.rs +++ b/src/memcmp.rs @@ -1,5 +1,7 @@ use anchor_lang::Discriminator; -use drift::state::user::User; +use drift::state::perp_market::PerpMarket; +use drift::state::spot_market::SpotMarket; +use drift::state::user::{MarketType, User}; use solana_client::rpc_filter::{Memcmp, RpcFilterType}; pub fn get_user_filter() -> RpcFilterType { @@ -7,7 +9,7 @@ pub fn get_user_filter() -> RpcFilterType { } pub fn get_non_idle_user_filter() -> RpcFilterType { - RpcFilterType::Memcmp(Memcmp::new_raw_bytes(4_350, vec![1])) + RpcFilterType::Memcmp(Memcmp::new_raw_bytes(4_350, vec![0])) } pub fn get_user_with_auction_filter() -> RpcFilterType { @@ -17,3 +19,14 @@ pub fn get_user_with_auction_filter() -> RpcFilterType { pub fn get_user_with_order_filter() -> RpcFilterType { RpcFilterType::Memcmp(Memcmp::new_raw_bytes(4_352, vec![1])) } + +pub fn get_market_filter(market_type: MarketType) -> RpcFilterType { + match market_type { + MarketType::Spot => { + RpcFilterType::Memcmp(Memcmp::new_raw_bytes(0, SpotMarket::discriminator().into())) + } + MarketType::Perp => { + RpcFilterType::Memcmp(Memcmp::new_raw_bytes(0, PerpMarket::discriminator().into())) + } + } +} diff --git a/src/oraclemap.rs b/src/oraclemap.rs new file mode 100644 index 0000000..7d01f35 --- /dev/null +++ b/src/oraclemap.rs @@ -0,0 +1,465 @@ +use crate::utils::get_ws_url; +use crate::websocket_account_subscriber::{AccountUpdate, WebsocketAccountSubscriber}; +use crate::{event_emitter::EventEmitter, SdkResult}; +use dashmap::DashMap; +use drift::state::oracle::{get_oracle_price, OraclePriceData, OracleSource}; +use solana_account_decoder::{UiAccountData, UiAccountEncoding}; +use solana_client::nonblocking::rpc_client::RpcClient; +use solana_client::rpc_config::RpcAccountInfoConfig; +use solana_sdk::account_info::{AccountInfo, IntoAccountInfo}; +use solana_sdk::{commitment_config::CommitmentConfig, pubkey::Pubkey}; +use std::cell::{Cell, RefCell}; +use std::str::FromStr; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; + +#[derive(Copy, Clone, Debug)] +pub struct OraclePriceDataAndSlot { + pub data: OraclePriceData, + pub slot: u64, +} + +pub(crate) struct OracleMap { + subscribed: Cell, + pub(crate) oraclemap: Arc>, + event_emitter: &'static EventEmitter, + oracle_infos: DashMap, + sync_lock: Option>, + latest_slot: Arc, + commitment: CommitmentConfig, + rpc: RpcClient, + oracle_subscribers: RefCell>, + perp_oracles: DashMap, + spot_oracles: DashMap, +} + +impl OracleMap { + pub fn new( + commitment: CommitmentConfig, + endpoint: String, + sync: bool, + perp_oracles: Vec<(u16, Pubkey, OracleSource)>, + spot_oracles: Vec<(u16, Pubkey, OracleSource)>, + ) -> Self { + let oraclemap = Arc::new(DashMap::new()); + + let event_emitter = EventEmitter::new(); + + let rpc = RpcClient::new_with_commitment(endpoint.clone(), commitment); + + let sync_lock = if sync { Some(Mutex::new(())) } else { None }; + + let mut all_oracles = vec![]; + all_oracles.extend(perp_oracles.clone()); + all_oracles.extend(spot_oracles.clone()); + + let oracle_infos_map: DashMap<_, _> = all_oracles + .iter() + .map(|(_, pubkey, oracle_source)| (*pubkey, *oracle_source)) + .collect(); + + let perp_oracles_map: DashMap<_, _> = perp_oracles + .iter() + .map(|(market_index, pubkey, _)| (*market_index, *pubkey)) + .collect(); + + let spot_oracles_map: DashMap<_, _> = spot_oracles + .iter() + .map(|(market_index, pubkey, _)| (*market_index, *pubkey)) + .collect(); + + Self { + subscribed: Cell::new(false), + oraclemap, + oracle_infos: oracle_infos_map, + sync_lock, + latest_slot: Arc::new(AtomicU64::new(0)), + commitment, + event_emitter: Box::leak(Box::new(event_emitter)), + rpc, + oracle_subscribers: RefCell::new(vec![]), + perp_oracles: perp_oracles_map, + spot_oracles: spot_oracles_map, + } + } + + pub async fn subscribe(&self) -> SdkResult<()> { + if self.sync_lock.is_some() { + self.sync().await?; + } + + if !self.subscribed.get() { + let url = get_ws_url(&self.rpc.url()).expect("valid url"); + let subscription_name: &'static str = "oraclemap"; + + let mut oracle_subscribers = vec![]; + for oracle_info in self.oracle_infos.iter() { + let oracle_pubkey = oracle_info.key(); + let oracle_subscriber = WebsocketAccountSubscriber::new( + subscription_name, + url.clone(), + *oracle_pubkey, + self.commitment, + self.event_emitter.clone(), + ); + oracle_subscribers.push(oracle_subscriber); + } + + self.subscribed.set(true); + + let oracle_source_by_oracle_key = self.oracle_infos.clone(); + let oracle_map = self.oraclemap.clone(); + + self.event_emitter.subscribe("oraclemap", move |event| { + if let Some(update) = event.as_any().downcast_ref::() { + let oracle_pubkey = Pubkey::from_str(&update.pubkey).expect("valid pubkey"); + let oracle_source_maybe = oracle_source_by_oracle_key.get(&oracle_pubkey); + if let Some(oracle_source) = oracle_source_maybe { + if let UiAccountData::Binary(blob, UiAccountEncoding::Base64) = + &update.data.data + { + let mut data = base64::decode(blob).expect("valid data"); + let owner = Pubkey::from_str(&update.data.owner).expect("valid pubkey"); + let mut lamports = update.data.lamports; + let oracle_account_info = AccountInfo::new( + &oracle_pubkey, + false, + false, + &mut lamports, + &mut data, + &owner, + false, + update.data.rent_epoch, + ); + match get_oracle_price( + oracle_source.value(), + &oracle_account_info, + update.slot, + ) { + Ok(price_data) => { + oracle_map.insert( + update.pubkey.clone(), + OraclePriceDataAndSlot { + data: price_data, + slot: update.slot, + }, + ); + } + Err(err) => { + log::error!("Failed to get oracle price: {:?}", err) + } + } + } + } + } + }); + + let mut subscribers_clone = oracle_subscribers.clone(); + + let subscribe_futures = subscribers_clone + .iter_mut() + .map(|subscriber| subscriber.subscribe()) + .collect::>(); + let results = futures_util::future::join_all(subscribe_futures).await; + results.into_iter().collect::, _>>()?; + + let mut oracle_subscribers_mut = self.oracle_subscribers.try_borrow_mut()?; + *oracle_subscribers_mut = oracle_subscribers; + } + + Ok(()) + } + + pub async fn unsubscribe(&self) -> SdkResult<()> { + if self.subscribed.get() { + let mut oracle_subscribers = self.oracle_subscribers.try_borrow_mut()?; + let unsubscribe_futures = oracle_subscribers + .iter_mut() + .map(|subscriber| subscriber.unsubscribe()) + .collect::>(); + + let results = futures_util::future::join_all(unsubscribe_futures).await; + results.into_iter().collect::, _>>()?; + self.subscribed.set(false); + self.oraclemap.clear(); + self.latest_slot.store(0, Ordering::Relaxed); + } + Ok(()) + } + + async fn sync(&self) -> SdkResult<()> { + let sync_lock = self.sync_lock.as_ref().expect("expected sync lock"); + + let lock = match sync_lock.try_lock() { + Ok(lock) => lock, + Err(_) => return Ok(()), + }; + + let account_config = RpcAccountInfoConfig { + commitment: Some(self.commitment), + encoding: None, + ..RpcAccountInfoConfig::default() + }; + + let mut pubkeys = self + .oracle_infos + .iter() + .map(|oracle_info_ref| *oracle_info_ref.key()) + .collect::>(); + pubkeys.sort(); + + let mut oracle_infos = self + .oracle_infos + .iter() + .map(|oracle_info_ref| (*oracle_info_ref.key(), *oracle_info_ref.value())) + .collect::>(); + oracle_infos.sort_by_key(|key| key.0); + + let response = self + .rpc + .get_multiple_accounts_with_config(&pubkeys, account_config) + .await?; + + if response.value.len() != pubkeys.len() { + return Err(crate::SdkError::Generic(format!( + "failed to get all oracle accounts, expected: {}, got: {}", + pubkeys.len(), + response.value.len() + ))); + } + + let slot = response.context.slot; + + for (account, oracle_info) in response.value.iter().zip(oracle_infos.iter()) { + if let Some(oracle_account) = account { + let oracle_pubkey = oracle_info.0; + let mut oracle_components = (oracle_pubkey, oracle_account.clone()); + let account_info = oracle_components.into_account_info(); + let price_data = get_oracle_price(&oracle_info.1, &account_info, slot) + .map_err(|err| crate::SdkError::Anchor(Box::new(err.into())))?; + self.oraclemap.insert( + oracle_pubkey.to_string(), + OraclePriceDataAndSlot { + data: price_data, + slot, + }, + ); + } + } + + self.latest_slot.store(slot, Ordering::Relaxed); + + drop(lock); + + Ok(()) + } + + pub fn size(&self) -> usize { + self.oraclemap.len() + } + + pub fn contains(&self, key: &Pubkey) -> bool { + self.oracle_infos.contains_key(key) + } + + pub fn current_perp_oracle(&self, market_index: u16) -> Option { + self.perp_oracles.get(&market_index).map(|x| *x) + } + + pub fn current_spot_oracle(&self, market_index: u16) -> Option { + self.spot_oracles.get(&market_index).map(|x| *x) + } + + pub fn get(&self, key: &str) -> Option { + self.oraclemap.get(key).map(|v| *v) + } + + pub fn values(&self) -> Vec { + self.oraclemap.iter().map(|x| x.value().data).collect() + } + + pub async fn add_oracle(&self, oracle: Pubkey, source: OracleSource) -> SdkResult<()> { + if self.contains(&oracle) { + return Ok(()); // don't add a duplicate + } + + self.oracle_infos.insert(oracle, source); + + let mut new_oracle_subscriber = WebsocketAccountSubscriber::new( + "oraclemap", + get_ws_url(&self.rpc.url()).expect("valid url"), + oracle, + self.commitment, + self.event_emitter.clone(), + ); + + new_oracle_subscriber.subscribe().await?; + let mut oracle_subscribers = self.oracle_subscribers.try_borrow_mut()?; + oracle_subscribers.push(new_oracle_subscriber); + + Ok(()) + } + + pub fn update_spot_oracle(&self, market_index: u16, oracle: Pubkey) { + self.spot_oracles.insert(market_index, oracle); + } + + pub fn update_perp_oracle(&self, market_index: u16, oracle: Pubkey) { + self.perp_oracles.insert(market_index, oracle); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::marketmap::MarketMap; + use drift::state::perp_market::PerpMarket; + use drift::state::spot_market::SpotMarket; + + #[tokio::test] + #[cfg(rpc_tests)] + async fn test_oracle_map() { + let commitment = CommitmentConfig::processed(); + let endpoint = "rpc".to_string(); + + let spot_market_map = + MarketMap::::new(commitment.clone(), endpoint.clone(), true); + let perp_market_map = + MarketMap::::new(commitment.clone(), endpoint.clone(), true); + + let _ = spot_market_map.sync().await; + let _ = perp_market_map.sync().await; + + let perp_oracles = perp_market_map.oracles(); + let spot_oracles = spot_market_map.oracles(); + + let mut oracles = vec![]; + oracles.extend(perp_oracles.clone()); + oracles.extend(spot_oracles.clone()); + + let mut oracle_infos = vec![]; + for oracle_info in oracles { + if !oracle_infos.contains(&oracle_info) { + oracle_infos.push(oracle_info) + } + } + + let oracle_infos_len = oracle_infos.len(); + dbg!(oracle_infos_len); + + let oracle_map = OracleMap::new(commitment, endpoint, true, perp_oracles, spot_oracles); + + let _ = oracle_map.subscribe().await; + + dbg!(oracle_map.size()); + // assert_eq!(oracle_map.size(), oracle_infos_len); + + dbg!("sleeping"); + tokio::time::sleep(tokio::time::Duration::from_secs(10)).await; + dbg!("done sleeping"); + + dbg!("perp market oracles"); + let mut last_sol_price = 0; + let mut last_sol_slot = 0; + let mut last_btc_price = 0; + let mut last_btc_slot = 0; + for _ in 0..10 { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + dbg!(); + let sol_perp_market_oracle_pubkey = perp_market_map + .get(&0) + .expect("sol perp market") + .data + .amm + .oracle; + let sol_oracle = oracle_map + .get(&sol_perp_market_oracle_pubkey.to_string()) + .expect("sol oracle"); + dbg!("sol oracle info:"); + dbg!(sol_oracle.data.price); + dbg!(sol_oracle.slot); + dbg!( + "sol price change: {}", + sol_oracle.data.price - last_sol_price + ); + dbg!("sol slot change: {}", sol_oracle.slot - last_sol_slot); + last_sol_price = sol_oracle.data.price; + last_sol_slot = sol_oracle.slot; + + dbg!(); + + let btc_perp_market_oracle_pubkey = perp_market_map + .get(&1) + .expect("btc perp market") + .data + .amm + .oracle; + let btc_oracle = oracle_map + .get(&btc_perp_market_oracle_pubkey.to_string()) + .expect("btc oracle"); + dbg!("btc oracle info:"); + dbg!(btc_oracle.data.price); + dbg!(btc_oracle.slot); + dbg!( + "btc price change: {}", + btc_oracle.data.price - last_btc_price + ); + dbg!("btc slot change: {}", btc_oracle.slot - last_btc_slot); + last_btc_price = btc_oracle.data.price; + last_btc_slot = btc_oracle.slot; + } + + dbg!(); + + dbg!("spot market oracles"); + let mut last_rndr_price = 0; + let mut last_rndr_slot = 0; + let mut last_weth_price = 0; + let mut last_weth_slot = 0; + for _ in 0..10 { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + dbg!(); + let rndr_spot_market_oracle_pubkey = spot_market_map + .get(&11) + .expect("sol perp market") + .data + .oracle; + let rndr_oracle = oracle_map + .get(&rndr_spot_market_oracle_pubkey.to_string()) + .expect("sol oracle"); + dbg!("rndr oracle info:"); + dbg!(rndr_oracle.data.price); + dbg!(rndr_oracle.slot); + dbg!( + "rndr price change: {}", + rndr_oracle.data.price - last_rndr_price + ); + dbg!("rndr slot change: {}", rndr_oracle.slot - last_rndr_slot); + last_rndr_price = rndr_oracle.data.price; + last_rndr_slot = rndr_oracle.slot; + + dbg!(); + + let weth_spot_market_oracle_pubkey = spot_market_map + .get(&4) + .expect("sol perp market") + .data + .oracle; + let weth_oracle = oracle_map + .get(&weth_spot_market_oracle_pubkey.to_string()) + .expect("sol oracle"); + dbg!("weth oracle info:"); + dbg!(weth_oracle.data.price); + dbg!(weth_oracle.slot); + dbg!( + "weth price change: {}", + weth_oracle.data.price - last_weth_price + ); + dbg!("weth slot change: {}", weth_oracle.slot - last_weth_slot); + last_weth_price = weth_oracle.data.price; + last_weth_slot = weth_oracle.slot; + } + + let _ = oracle_map.unsubscribe().await; + } +} diff --git a/src/types.rs b/src/types.rs index 45f5783..f410ea2 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,4 +1,7 @@ -use std::cmp::Ordering; +use std::{ + cell::{BorrowError, BorrowMutError}, + cmp::Ordering, +}; use anchor_lang::AccountDeserialize; use drift::{error::ErrorCode, state::user::UserStats}; @@ -240,6 +243,14 @@ pub enum SdkError { CouldntUnsubscribe(#[from] tokio::sync::mpsc::error::SendError<()>), #[error("MathError")] MathError(String), + #[error("{0}")] + BorrowMutError(#[from] BorrowMutError), + #[error("{0}")] + BorrowError(#[from] BorrowError), + #[error("{0}")] + Generic(String), + #[error("max connection attempts reached")] + MaxReconnectionAttemptsReached, } impl SdkError { diff --git a/src/usermap.rs b/src/usermap.rs index ab93985..1677408 100644 --- a/src/usermap.rs +++ b/src/usermap.rs @@ -16,12 +16,13 @@ use serde_json::json; use solana_account_decoder::UiAccountEncoding; use solana_client::nonblocking::rpc_client::RpcClient; use solana_client::rpc_config::{RpcAccountInfoConfig, RpcProgramAccountsConfig}; +use solana_client::rpc_filter::RpcFilterType; use solana_client::rpc_request::RpcRequest; use solana_client::rpc_response::{OptionalContext, RpcKeyedAccount}; use solana_sdk::commitment_config::CommitmentConfig; use solana_sdk::pubkey::Pubkey; -pub struct Usermap { +pub struct UserMap { subscribed: bool, subscription: WebsocketProgramAccountSubscriber, pub(crate) usermap: Arc>, @@ -31,9 +32,15 @@ pub struct Usermap { rpc: RpcClient, } -impl Usermap { - pub fn new(commitment: CommitmentConfig, endpoint: String, sync: bool) -> Self { - let filters = vec![get_user_filter(), get_non_idle_user_filter()]; +impl UserMap { + pub fn new( + commitment: CommitmentConfig, + endpoint: String, + sync: bool, + additional_filters: Option>, + ) -> Self { + let mut filters = vec![get_user_filter(), get_non_idle_user_filter()]; + filters.extend(additional_filters.unwrap_or_default()); let options = WebsocketProgramAccountOptions { filters, commitment, @@ -64,7 +71,7 @@ impl Usermap { } pub async fn subscribe(&mut self) -> SdkResult<()> { - if let Some(_) = self.sync_lock { + if self.sync_lock.is_some() { self.sync().await?; } @@ -123,7 +130,7 @@ impl Usermap { .get_account_data(&Pubkey::from_str(pubkey).unwrap()) .await?; let user = User::try_deserialize(&mut user_data.as_slice()).unwrap(); - self.usermap.insert(pubkey.to_string(), user.clone()); + self.usermap.insert(pubkey.to_string(), user); Ok(self.get(pubkey).unwrap()) } } @@ -183,7 +190,7 @@ mod tests { #[tokio::test] #[cfg(rpc_tests)] async fn test_usermap() { - use crate::usermap::Usermap; + use crate::usermap::UserMap; use solana_sdk::commitment_config::CommitmentConfig; use solana_sdk::commitment_config::CommitmentLevel; @@ -192,7 +199,7 @@ mod tests { commitment: CommitmentLevel::Processed, }; - let mut usermap = Usermap::new(commitment, endpoint, true); + let mut usermap = UserMap::new(commitment, endpoint, true); usermap.subscribe().await.unwrap(); tokio::time::sleep(tokio::time::Duration::from_secs(30)).await; diff --git a/src/websocket_account_subscriber.rs b/src/websocket_account_subscriber.rs new file mode 100644 index 0000000..65ade3f --- /dev/null +++ b/src/websocket_account_subscriber.rs @@ -0,0 +1,171 @@ +use futures_util::StreamExt; +use solana_account_decoder::{UiAccount, UiAccountEncoding}; +use solana_client::{nonblocking::pubsub_client::PubsubClient, rpc_config::RpcAccountInfoConfig}; +use solana_sdk::{commitment_config::CommitmentConfig, pubkey::Pubkey}; + +use crate::{ + event_emitter::{Event, EventEmitter}, + SdkResult, +}; + +#[derive(Clone, Debug)] +pub(crate) struct AccountUpdate { + pub pubkey: String, + pub data: UiAccount, + pub slot: u64, +} + +impl Event for AccountUpdate { + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +#[derive(Clone)] +pub struct WebsocketAccountSubscriber { + subscription_name: &'static str, + url: String, + pubkey: Pubkey, + pub(crate) commitment: CommitmentConfig, + pub subscribed: bool, + pub event_emitter: EventEmitter, + unsubscriber: Option>, +} + +impl WebsocketAccountSubscriber { + pub fn new( + subscription_name: &'static str, + url: String, + pubkey: Pubkey, + commitment: CommitmentConfig, + event_emitter: EventEmitter, + ) -> Self { + WebsocketAccountSubscriber { + subscription_name, + url, + pubkey, + commitment, + subscribed: false, + event_emitter, + unsubscriber: None, + } + } + + pub async fn subscribe(&mut self) -> SdkResult<()> { + if self.subscribed { + return Ok(()); + } + + self.subscribed = true; + self.subscribe_ws().await?; + Ok(()) + } + + async fn subscribe_ws(&mut self) -> SdkResult<()> { + let account_config = RpcAccountInfoConfig { + commitment: Some(self.commitment), + encoding: Some(UiAccountEncoding::Base64), + ..RpcAccountInfoConfig::default() + }; + let (unsub_tx, mut unsub_rx) = tokio::sync::mpsc::channel::<()>(1); + self.unsubscriber = Some(unsub_tx); + + let mut attempt = 0; + let max_reconnection_attempts = 20; + let base_delay = tokio::time::Duration::from_secs(2); + + let url = self.url.clone(); + + tokio::spawn({ + let event_emitter = self.event_emitter.clone(); + let mut latest_slot = 0; + let subscription_name = self.subscription_name; + let pubkey = self.pubkey.clone(); + async move { + loop { + let pubsub = PubsubClient::new(&url).await?; + + match pubsub + .account_subscribe(&pubkey, Some(account_config.clone())) + .await + { + Ok((mut account_updates, account_unsubscribe)) => loop { + tokio::select! { + message = account_updates.next() => { + match message { + Some(message) => { + let slot = message.context.slot; + if slot >= latest_slot { + latest_slot = slot; + let account_update = AccountUpdate { + pubkey: pubkey.to_string(), + data: message.value, + slot, + }; + event_emitter.emit(subscription_name, Box::new(account_update)); + } + } + None => { + log::warn!("{}: Account stream interrupted", subscription_name); + account_unsubscribe().await; + break; + } + } + } + unsub = unsub_rx.recv() => { + if let Some(_) = unsub { + log::debug!("{}: Unsubscribing from account stream", subscription_name); + account_unsubscribe().await; + return Ok(()); + + } + } + } + }, + Err(_) => { + log::error!( + "{}: Failed to subscribe to account stream, retrying", + subscription_name + ); + attempt += 1; + if attempt >= max_reconnection_attempts { + log::error!("Max reconnection attempts reached."); + return Err(crate::SdkError::MaxReconnectionAttemptsReached); + } + } + } + + if attempt >= max_reconnection_attempts { + log::error!("{}: Max reconnection attempts reached", subscription_name); + return Err(crate::SdkError::MaxReconnectionAttemptsReached); + } + + let delay_duration = base_delay * 2_u32.pow(attempt); + log::debug!( + "{}: Reconnecting in {:?}", + subscription_name, + delay_duration + ); + tokio::time::sleep(delay_duration).await; + attempt += 1; + } + } + }); + Ok(()) + } + + pub async fn unsubscribe(&mut self) -> SdkResult<()> { + if self.subscribed && self.unsubscriber.is_some() { + if let Err(e) = self.unsubscriber.as_ref().unwrap().send(()).await { + log::error!("Failed to send unsubscribe signal: {:?}", e); + return Err(crate::SdkError::CouldntUnsubscribe(e)); + } + self.subscribed = false; + } + Ok(()) + } +} diff --git a/src/websocket_program_account_subscriber.rs b/src/websocket_program_account_subscriber.rs index d8be2bb..5f706c2 100644 --- a/src/websocket_program_account_subscriber.rs +++ b/src/websocket_program_account_subscriber.rs @@ -42,6 +42,7 @@ impl Event for ProgramAccountUpd } } +#[derive(Clone)] pub struct WebsocketProgramAccountOptions { pub filters: Vec, pub commitment: CommitmentConfig, @@ -102,53 +103,82 @@ impl WebsocketProgramAccountSubscriber { ..RpcProgramAccountsConfig::default() }; - let pubsub = PubsubClient::new(&self.url).await?; let (unsub_tx, mut unsub_rx) = tokio::sync::mpsc::channel::<()>(1); self.unsubscriber = Some(unsub_tx); + let mut attempt = 0; + let max_reconnection_attempts = 20; + let base_delay = tokio::time::Duration::from_secs(5); + + let url = self.url.clone(); tokio::spawn({ let event_emitter = self.event_emitter.clone(); let mut latest_slot = 0; let subscription_name = self.subscription_name; async move { - let (mut accounts, unsubscriber) = pubsub - .program_subscribe(&drift::ID, Some(config)) - .await - .unwrap(); loop { - tokio::select! { - message = accounts.next() => { - match message { - Some(message) => { - let slot = message.context.slot; - if slot >= latest_slot { - latest_slot = slot; - let pubkey = message.value.pubkey; - let account_data = message.value.account.data; - match decode(account_data) { - Ok(data) => { - let data_and_slot = DataAndSlot:: { slot, data }; - event_emitter.emit(subscription_name, Box::new(ProgramAccountUpdate::new(pubkey, data_and_slot))); - }, - Err(e) => { - error!("Error decoding account data {e}"); + let pubsub = PubsubClient::new(&url).await?; + match pubsub + .program_subscribe(&drift::ID, Some(config.clone())) + .await + { + Ok((mut accounts, unsubscriber)) => loop { + tokio::select! { + message = accounts.next() => { + match message { + Some(message) => { + let slot = message.context.slot; + if slot >= latest_slot { + latest_slot = slot; + let pubkey = message.value.pubkey; + let account_data = message.value.account.data; + match decode(account_data) { + Ok(data) => { + let data_and_slot = DataAndSlot:: { slot, data }; + event_emitter.emit(subscription_name, Box::new(ProgramAccountUpdate::new(pubkey, data_and_slot))); + }, + Err(e) => { + error!("Error decoding account data {e}"); + } + } } } + None => { + warn!("{} stream ended", subscription_name); + unsubscriber().await; + break; + } } } - None => { - warn!("{} stream ended", subscription_name); + _ = unsub_rx.recv() => { + debug!("Unsubscribing."); unsubscriber().await; - break; + return Ok(()); } } - } - _ = unsub_rx.recv() => { - debug!("Unsubscribing."); - unsubscriber().await; - break; + }, + Err(_) => { + error!("Failed to subscribe to program stream, retrying."); + attempt += 1; + if attempt >= max_reconnection_attempts { + error!("Max reconnection attempts reached."); + return Err(SdkError::MaxReconnectionAttemptsReached); + } } } + + if attempt >= max_reconnection_attempts { + error!("Max reconnection attempts reached."); + return Err(SdkError::MaxReconnectionAttemptsReached); + } + + let delay_duration = base_delay * 2_u32.pow(attempt); + debug!( + "{}: Reconnecting in {:?}", + subscription_name, delay_duration + ); + tokio::time::sleep(delay_duration).await; + attempt += 1; } } });