diff --git a/Cargo.lock b/Cargo.lock
index 6cf99ba..ddb4caa 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1120,6 +1120,7 @@ dependencies = [
"drift",
"env_logger 0.10.2",
"fnv",
+ "futures",
"futures-util",
"hex",
"hex-literal",
diff --git a/Cargo.toml b/Cargo.toml
index 1255a1e..3f83585 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -57,3 +57,4 @@ spl-associated-token-account = "1.1.2"
anchor-client = "0.27.0"
anchor-lang = "*"
bytes = "*"
+futures = "0.3.30"
diff --git a/README.md b/README.md
index 3995619..6b51c7b 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
-
Drift Protocol v2 (Rust SDK)
+
drift-rs
@@ -11,9 +11,9 @@
-# Drift Protocol v2 (Rust SDK)
+# drift-rs
-Rust SDK for building off chain clients for interacting with the [Drift V2](https://github.com/drift-labs/protocol-v2) protocol.
+Experimental, high performance Rust SDK for building off chain clients for interacting with the [Drift V2](https://github.com/drift-labs/protocol-v2) protocol.
## Setup
diff --git a/src/dlob/dlob.rs b/src/dlob/dlob.rs
index eda1154..ad55014 100644
--- a/src/dlob/dlob.rs
+++ b/src/dlob/dlob.rs
@@ -14,7 +14,7 @@ use crate::dlob::dlob_node::{
use crate::dlob::market::{get_order_lists, Exchange, Market, OpenOrders, SubType};
use crate::event_emitter::Event;
use crate::math::order::is_resting_limit_order;
-use crate::usermap::Usermap;
+use crate::usermap::UserMap;
use crate::utils::market_type_to_string;
#[derive(Clone)]
@@ -43,7 +43,7 @@ impl DLOB {
}
}
- pub fn build_from_usermap(&mut self, usermap: &Usermap, slot: u64) {
+ pub fn build_from_usermap(&mut self, usermap: &UserMap, slot: u64) {
usermap.usermap.iter().par_bridge().for_each(|user_ref| {
let user = user_ref.value();
let user_key = user_ref.key();
@@ -98,14 +98,14 @@ impl DLOB {
.get_mut(&market_index)
.expect(format!("Market index {} not found", market_index).as_str());
- let (order_list, subtype, node_type) = market.get_info_for_order_insert(&order, slot);
+ let (order_list, subtype, node_type) = market.get_info_for_order_insert(order, slot);
- let node = create_node(node_type, order.clone(), user_account);
+ let node = create_node(node_type, *order, user_account);
if let Some(order_list) = order_list {
match subtype {
- SubType::Bid => order_list.insert_bid(node.clone()),
- SubType::Ask => order_list.insert_ask(node.clone()),
+ SubType::Bid => order_list.insert_bid(node),
+ SubType::Ask => order_list.insert_ask(node),
_ => {}
}
} else {
@@ -117,7 +117,7 @@ impl DLOB {
let order_signature = get_order_signature(order_id, user_account);
for order_list in get_order_lists(&self.exchange) {
if let Some(node) = order_list.get_node(&order_signature) {
- return Some(node.get_order().clone());
+ return Some(*node.get_order());
}
}
diff --git a/src/dlob/dlob_builder.rs b/src/dlob/dlob_builder.rs
index 9143c0a..3104acf 100644
--- a/src/dlob/dlob_builder.rs
+++ b/src/dlob/dlob_builder.rs
@@ -1,14 +1,11 @@
use crate::{
- dlob::dlob::DLOB,
- event_emitter::{Event, EventEmitter},
- slot_subscriber::SlotSubscriber,
- usermap::Usermap,
- SdkResult,
+ dlob::dlob::DLOB, event_emitter::EventEmitter, slot_subscriber::SlotSubscriber,
+ usermap::UserMap, SdkResult,
};
pub struct DLOBBuilder {
slot_subscriber: SlotSubscriber,
- usermap: Usermap,
+ usermap: UserMap,
rebuild_frequency: u64,
dlob: DLOB,
event_emitter: EventEmitter,
@@ -17,7 +14,7 @@ pub struct DLOBBuilder {
impl DLOBBuilder {
pub fn new(
slot_subscriber: SlotSubscriber,
- usermap: Usermap,
+ usermap: UserMap,
rebuild_frequency: u64,
) -> DLOBBuilder {
DLOBBuilder {
@@ -56,6 +53,7 @@ impl DLOBBuilder {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::memcmp::get_user_with_order_filter;
use crate::utils::get_ws_url;
use solana_sdk::commitment_config::CommitmentConfig;
use solana_sdk::commitment_config::CommitmentLevel;
@@ -70,7 +68,12 @@ mod tests {
};
let slot_subscriber = SlotSubscriber::new(get_ws_url(&endpoint.clone()).unwrap());
- let usermap = Usermap::new(commitment, endpoint, true);
+ let mut usermap = UserMap::new(
+ commitment,
+ endpoint,
+ true,
+ Some(vec![get_user_with_order_filter()]),
+ );
let mut dlob_builder = DLOBBuilder::new(slot_subscriber, usermap, 30);
dlob_builder
@@ -90,13 +93,18 @@ mod tests {
#[tokio::test]
#[cfg(rpc_tests)]
async fn test_build_time() {
- let endpoint = "url".to_string();
+ let endpoint = "rpc".to_string();
let commitment = CommitmentConfig {
commitment: CommitmentLevel::Processed,
};
let mut slot_subscriber = SlotSubscriber::new(get_ws_url(&endpoint.clone()).unwrap());
- let mut usermap = Usermap::new(commitment, endpoint, true);
+ let mut usermap = UserMap::new(
+ commitment,
+ endpoint,
+ true,
+ Some(vec![get_user_with_order_filter()]),
+ );
let _ = slot_subscriber.subscribe().await;
let _ = usermap.subscribe().await;
diff --git a/src/dlob/dlob_node.rs b/src/dlob/dlob_node.rs
index a3219d8..8a2972e 100644
--- a/src/dlob/dlob_node.rs
+++ b/src/dlob/dlob_node.rs
@@ -298,7 +298,7 @@ pub(crate) fn create_node(node_type: NodeType, order: Order, user_account: Pubke
}
pub(crate) fn get_order_signature(order_id: u32, user_account: Pubkey) -> String {
- format!("{}-{}", order_id, user_account.to_string())
+ format!("{}-{}", order_id, user_account)
}
#[cfg(test)]
diff --git a/src/dlob/market.rs b/src/dlob/market.rs
index ce1b759..0d5c409 100644
--- a/src/dlob/market.rs
+++ b/src/dlob/market.rs
@@ -61,12 +61,10 @@ impl Market {
NodeType::Market
} else if order.oracle_price_offset != 0 {
NodeType::FloatingLimit
+ } else if is_resting_limit_order(order, slot) {
+ NodeType::RestingLimit
} else {
- if is_resting_limit_order(order, slot) {
- NodeType::RestingLimit
- } else {
- NodeType::TakingLimit
- }
+ NodeType::TakingLimit
};
let order_list = match node_type {
diff --git a/src/dlob/order_list.rs b/src/dlob/order_list.rs
index 921af38..5a5c48e 100644
--- a/src/dlob/order_list.rs
+++ b/src/dlob/order_list.rs
@@ -26,20 +26,20 @@ impl Orderlist {
pub fn insert_bid(&mut self, node: Node) {
let order_sig = get_order_signature(node.get_order().order_id, node.get_user_account());
- self.order_sigs.insert(order_sig.clone(), node.clone());
+ self.order_sigs.insert(order_sig.clone(), node);
let directional = DirectionalNode::new(node, self.bid_sort_direction);
self.bids.push(directional);
}
pub fn insert_ask(&mut self, node: Node) {
let order_sig = get_order_signature(node.get_order().order_id, node.get_user_account());
- self.order_sigs.insert(order_sig.clone(), node.clone());
+ self.order_sigs.insert(order_sig.clone(), node);
let directional = DirectionalNode::new(node, self.ask_sort_direction);
self.asks.push(directional);
}
pub fn get_best_bid(&mut self) -> Option {
- if let Some(node) = self.bids.pop().map(|node| node.node.clone()) {
+ if let Some(node) = self.bids.pop().map(|node| node.node) {
let order_sig = get_order_signature(node.get_order().order_id, node.get_user_account());
if self.order_sigs.contains_key(&order_sig) {
self.order_sigs.remove(&order_sig);
@@ -50,7 +50,7 @@ impl Orderlist {
}
pub fn get_best_ask(&mut self) -> Option {
- if let Some(node) = self.asks.pop().map(|node| node.node.clone()) {
+ if let Some(node) = self.asks.pop().map(|node| node.node) {
let order_sig = get_order_signature(node.get_order().order_id, node.get_user_account());
if self.order_sigs.contains_key(&order_sig) {
self.order_sigs.remove(&order_sig);
@@ -61,7 +61,7 @@ impl Orderlist {
}
pub fn get_node(&self, order_sig: &String) -> Option {
- self.order_sigs.get(order_sig).map(|node| node.clone())
+ self.order_sigs.get(order_sig).map(|node| *node)
}
pub fn bids_empty(&self) -> bool {
diff --git a/src/lib.rs b/src/lib.rs
index f4f22ea..d8e344e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,6 +1,6 @@
//! Drift SDK
-use std::{borrow::Cow, sync::Arc, time::Duration};
+use std::{borrow::Cow, rc::Rc, sync::Arc, time::Duration};
use anchor_lang::{AccountDeserialize, Discriminator, InstructionData, ToAccountMetas};
use async_utils::{retry_policy, spawn_retry_task};
@@ -20,6 +20,8 @@ use drift::{
use fnv::FnvHashMap;
use futures_util::{future::BoxFuture, FutureExt, StreamExt};
use log::{debug, warn};
+use marketmap::MarketMap;
+use oraclemap::{OracleMap, OraclePriceDataAndSlot};
use solana_account_decoder::UiAccountEncoding;
use solana_client::{
nonblocking::{pubsub_client::PubsubClient, rpc_client::RpcClient},
@@ -54,6 +56,7 @@ use crate::constants::{
// utils
pub mod async_utils;
+pub mod math;
pub mod memcmp;
pub mod utils;
@@ -63,19 +66,21 @@ pub mod types;
// internal infra
pub mod event_emitter;
+pub mod websocket_account_subscriber;
pub mod websocket_program_account_subscriber;
// subscribers
pub mod auction_subscriber;
pub mod dlob_client;
pub mod event_subscriber;
+pub mod marketmap;
+pub mod oraclemap;
#[cfg(feature = "jit")]
pub mod jit_client;
pub mod slot_subscriber;
+pub mod usermap;
pub mod dlob;
-pub mod math;
-pub mod usermap;
use types::*;
@@ -322,6 +327,18 @@ impl DriftClient {
})
}
+ /// Subscribe to the Drift Client Backend
+ /// This is a no-op if already subscribed
+ pub async fn subscribe(&self) -> SdkResult<()> {
+ self.backend.subscribe().await
+ }
+
+ /// Unsubscribe from the Drift Client Backend
+ /// This is a no-op if not subscribed
+ pub async fn unsubscribe(&self) -> SdkResult<()> {
+ self.backend.unsubscribe().await
+ }
+
/// Return a handle to the inner RPC client
pub fn inner(&self) -> &RpcClient {
self.backend.client()
@@ -586,6 +603,63 @@ impl DriftClient {
.get_recent_priority_fees(writable_markets, window)
.await
}
+
+ pub fn get_perp_market_account_and_slot(
+ &self,
+ market_index: u16,
+ ) -> Option> {
+ self.backend.get_perp_market_account_and_slot(market_index)
+ }
+
+ pub fn get_spot_market_account_and_slot(
+ &self,
+ market_index: u16,
+ ) -> Option> {
+ self.backend.get_spot_market_account_and_slot(market_index)
+ }
+
+ pub fn get_perp_market_account(&self, market_index: u16) -> Option {
+ self.backend
+ .get_perp_market_account_and_slot(market_index)
+ .map(|x| x.data)
+ }
+
+ pub fn get_spot_market_account(&self, market_index: u16) -> Option {
+ self.backend
+ .get_spot_market_account_and_slot(market_index)
+ .map(|x| x.data)
+ }
+
+ pub fn num_perp_markets(&self) -> usize {
+ self.backend.num_perp_markets()
+ }
+
+ pub fn num_spot_markets(&self) -> usize {
+ self.backend.num_spot_markets()
+ }
+
+ pub fn get_oracle_price_data_and_slot(
+ &self,
+ oracle_pubkey: Pubkey,
+ ) -> Option {
+ self.backend.get_oracle_price_data_and_slot(oracle_pubkey)
+ }
+
+ pub fn get_oracle_price_data_and_slot_for_perp_market(
+ &self,
+ market_index: u16,
+ ) -> Option {
+ self.backend
+ .get_oracle_price_data_and_slot_for_perp_market(market_index)
+ }
+
+ pub fn get_oracle_price_data_and_slot_for_spot_market(
+ &self,
+ market_index: u16,
+ ) -> Option {
+ self.backend
+ .get_oracle_price_data_and_slot_for_spot_market(market_index)
+ }
}
/// Provides the heavy-lifting and network facing features of the SDK
@@ -594,6 +668,9 @@ pub struct DriftClientBackend {
rpc_client: RpcClient,
account_provider: T,
program_data: ProgramData,
+ perp_market_map: MarketMap,
+ spot_market_map: MarketMap,
+ oracle_map: Rc,
}
impl DriftClientBackend {
@@ -604,10 +681,37 @@ impl DriftClientBackend {
account_provider.commitment_config(),
);
+ let perp_market_map = MarketMap::::new(
+ account_provider.commitment_config(),
+ account_provider.endpoint(),
+ true,
+ );
+ let spot_market_map = MarketMap::::new(
+ account_provider.commitment_config(),
+ account_provider.endpoint(),
+ true,
+ );
+
+ tokio::try_join!(perp_market_map.sync(), spot_market_map.sync())?;
+
+ let perp_oracles = perp_market_map.oracles();
+ let spot_oracles = spot_market_map.oracles();
+
+ let oracle_map = OracleMap::new(
+ account_provider.commitment_config(),
+ account_provider.endpoint(),
+ true,
+ perp_oracles,
+ spot_oracles,
+ );
+
let mut this = Self {
rpc_client,
account_provider,
program_data: ProgramData::uninitialized(),
+ perp_market_map,
+ spot_market_map,
+ oracle_map: Rc::new(oracle_map),
};
let lookup_table_address = market_lookup_table(context);
@@ -626,6 +730,101 @@ impl DriftClientBackend {
Ok(this)
}
+ async fn subscribe(&self) -> SdkResult<()> {
+ tokio::try_join!(
+ self.perp_market_map.subscribe(),
+ self.spot_market_map.subscribe(),
+ self.oracle_map.subscribe()
+ )?;
+ Ok(())
+ }
+
+ async fn unsubscribe(&self) -> SdkResult<()> {
+ tokio::try_join!(
+ self.perp_market_map.unsubscribe(),
+ self.spot_market_map.unsubscribe(),
+ self.oracle_map.unsubscribe()
+ )?;
+ Ok(())
+ }
+
+ fn get_perp_market_account_and_slot(
+ &self,
+ market_index: u16,
+ ) -> Option> {
+ self.perp_market_map.get(&market_index)
+ }
+
+ fn get_spot_market_account_and_slot(
+ &self,
+ market_index: u16,
+ ) -> Option> {
+ self.spot_market_map.get(&market_index)
+ }
+
+ fn num_perp_markets(&self) -> usize {
+ self.perp_market_map.size()
+ }
+
+ fn num_spot_markets(&self) -> usize {
+ self.spot_market_map.size()
+ }
+
+ fn get_oracle_price_data_and_slot(
+ &self,
+ oracle_pubkey: Pubkey,
+ ) -> Option {
+ self.oracle_map.get(&oracle_pubkey.to_string())
+ }
+
+ fn get_oracle_price_data_and_slot_for_perp_market(
+ &self,
+ market_index: u16,
+ ) -> Option {
+ let market = self.get_perp_market_account_and_slot(market_index)?;
+
+ let oracle = market.data.amm.oracle;
+ let current_oracle = self
+ .oracle_map
+ .current_perp_oracle(market_index)
+ .expect("oracle");
+
+ if oracle != current_oracle {
+ let source = market.data.amm.oracle_source;
+ let clone = self.oracle_map.clone();
+ tokio::task::spawn_local(async move {
+ let _ = clone.add_oracle(oracle, source).await;
+ clone.update_perp_oracle(market_index, oracle)
+ });
+ }
+
+ self.get_oracle_price_data_and_slot(current_oracle)
+ }
+
+ fn get_oracle_price_data_and_slot_for_spot_market(
+ &self,
+ market_index: u16,
+ ) -> Option {
+ let market = self.get_spot_market_account_and_slot(market_index)?;
+
+ let oracle = market.data.oracle;
+ let current_oracle = self
+ .oracle_map
+ .current_spot_oracle(market_index)
+ .expect("oracle");
+
+ if oracle != current_oracle {
+ let source = market.data.oracle_source;
+ let clone = self.oracle_map.clone();
+ tokio::task::spawn_local(async move {
+ let _ = clone.add_oracle(oracle, source).await;
+ clone.update_spot_oracle(market_index, oracle);
+ });
+ }
+
+ self.get_oracle_price_data_and_slot(market.data.oracle)
+ }
+
/// Return a handle to the inner RPC client
fn client(&self) -> &RpcClient {
&self.rpc_client
@@ -1612,6 +1811,17 @@ mod tests {
account_provider_mocks: Mocks,
keypair: Keypair,
) -> DriftClient {
+ let perp_market_map = MarketMap::::new(
+ CommitmentConfig::processed(),
+ DEVNET_ENDPOINT.to_string(),
+ false,
+ );
+ let spot_market_map = MarketMap::::new(
+ CommitmentConfig::processed(),
+ DEVNET_ENDPOINT.to_string(),
+ false,
+ );
+
let backend = DriftClientBackend {
rpc_client: RpcClient::new_mock_with_mocks(DEVNET_ENDPOINT.to_string(), rpc_mocks),
account_provider: RpcAccountProvider {
@@ -1621,6 +1831,15 @@ mod tests {
)),
},
program_data: ProgramData::uninitialized(),
+ perp_market_map,
+ spot_market_map,
+ oracle_map: Rc::new(OracleMap::new(
+ CommitmentConfig::processed(),
+ DEVNET_ENDPOINT.to_string(),
+ true,
+ vec![],
+ vec![],
+ )),
};
DriftClient {
@@ -1631,6 +1850,38 @@ mod tests {
}
}
+ #[tokio::test]
+ #[cfg(rpc_tests)]
+ async fn test_marketmap_subscribe() {
+ let endpoint = "rpc";
+
+ let client = DriftClient::new(
+ Context::MainNet,
+ RpcAccountProvider::new(endpoint),
+ Keypair::new().into(),
+ )
+ .await
+ .unwrap();
+
+ let _ = client.subscribe().await;
+
+ tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
+
+ for _ in 0..20 {
+ tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
+ let perp_market = client.get_perp_market_account_and_slot(0);
+ let slot = perp_market.unwrap().slot;
+ dbg!(slot);
+ }
+
+ for _ in 0..20 {
+ tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
+ let spot_market = client.get_spot_market_account_and_slot(0);
+ let slot = spot_market.unwrap().slot;
+ dbg!(slot);
+ }
+ }
+
#[tokio::test]
async fn get_market_accounts() {
let client = DriftClient::new(
diff --git a/src/marketmap.rs b/src/marketmap.rs
new file mode 100644
index 0000000..9fc9034
--- /dev/null
+++ b/src/marketmap.rs
@@ -0,0 +1,282 @@
+use std::cell::{Cell, RefCell};
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::sync::{Arc, Mutex};
+
+use crate::event_emitter::EventEmitter;
+use crate::memcmp::get_market_filter;
+use crate::utils::{decode, get_ws_url};
+use crate::websocket_program_account_subscriber::{
+ ProgramAccountUpdate, WebsocketProgramAccountOptions, WebsocketProgramAccountSubscriber,
+};
+use crate::{DataAndSlot, SdkResult};
+use anchor_lang::AccountDeserialize;
+use dashmap::DashMap;
+use drift::state::oracle::OracleSource;
+use drift::state::perp_market::PerpMarket;
+use drift::state::spot_market::SpotMarket;
+use drift::state::user::MarketType;
+use serde_json::json;
+use solana_account_decoder::UiAccountEncoding;
+use solana_client::nonblocking::rpc_client::RpcClient;
+use solana_client::rpc_config::{RpcAccountInfoConfig, RpcProgramAccountsConfig};
+use solana_client::rpc_request::RpcRequest;
+use solana_client::rpc_response::{OptionalContext, RpcKeyedAccount};
+use solana_sdk::commitment_config::CommitmentConfig;
+use solana_sdk::pubkey::Pubkey;
+
+pub trait Market {
+ const MARKET_TYPE: MarketType;
+ fn market_index(&self) -> u16;
+ fn oracle_info(&self) -> (u16, Pubkey, OracleSource);
+}
+
+impl Market for PerpMarket {
+ const MARKET_TYPE: MarketType = MarketType::Perp;
+
+ fn market_index(&self) -> u16 {
+ self.market_index
+ }
+
+ fn oracle_info(&self) -> (u16, Pubkey, OracleSource) {
+ (self.market_index(), self.amm.oracle, self.amm.oracle_source)
+ }
+}
+
+impl Market for SpotMarket {
+ const MARKET_TYPE: MarketType = MarketType::Spot;
+
+ fn market_index(&self) -> u16 {
+ self.market_index
+ }
+
+ fn oracle_info(&self) -> (u16, Pubkey, OracleSource) {
+ (self.market_index(), self.oracle, self.oracle_source)
+ }
+}
+
+pub struct MarketMap {
+ subscribed: Cell,
+ subscription: RefCell,
+ marketmap: Arc>>,
+ sync_lock: Option>,
+ latest_slot: Arc,
+ commitment: CommitmentConfig,
+ rpc: RpcClient,
+ synced: bool,
+}
+
+impl MarketMap {
+ pub fn new(commitment: CommitmentConfig, endpoint: String, sync: bool) -> Self {
+ let filters = vec![get_market_filter(T::MARKET_TYPE)];
+ let options = WebsocketProgramAccountOptions {
+ filters,
+ commitment,
+ encoding: UiAccountEncoding::Base64,
+ };
+ let event_emitter = EventEmitter::new();
+
+ let url = get_ws_url(&endpoint.clone()).unwrap();
+
+ let subscription =
+ WebsocketProgramAccountSubscriber::new("marketmap", url, options, event_emitter);
+
+ let marketmap = Arc::new(DashMap::new());
+
+ let rpc = RpcClient::new_with_commitment(endpoint.clone(), commitment);
+
+ let sync_lock = if sync { Some(Mutex::new(())) } else { None };
+
+ Self {
+ subscribed: Cell::new(false),
+ subscription: RefCell::new(subscription),
+ marketmap,
+ sync_lock,
+ latest_slot: Arc::new(AtomicU64::new(0)),
+ commitment,
+ rpc,
+ synced: false,
+ }
+ }
+
+ pub async fn subscribe(&self) -> SdkResult<()> {
+ if self.sync_lock.is_some() {
+ self.sync().await?;
+ }
+
+ if !self.subscribed.get() {
+ self.subscription.try_borrow_mut()?.subscribe::().await?;
+ self.subscribed.set(true);
+
+ let marketmap = self.marketmap.clone();
+ let latest_slot = self.latest_slot.clone();
+
+ self.subscription
+ .try_borrow()?
+ .event_emitter
+ .subscribe("marketmap", move |event| {
+ if let Some(update) = event.as_any().downcast_ref::>() {
+ let market_data_and_slot = update.data_and_slot.clone();
+ if update.data_and_slot.slot > latest_slot.load(Ordering::Relaxed) {
+ latest_slot.store(update.data_and_slot.slot, Ordering::Relaxed);
+ }
+ marketmap.insert(
+ update.data_and_slot.clone().data.market_index(),
+ DataAndSlot {
+ data: market_data_and_slot.data,
+ slot: update.data_and_slot.slot,
+ },
+ );
+ }
+ });
+ }
+ Ok(())
+ }
+
+ pub async fn unsubscribe(&self) -> SdkResult<()> {
+ if self.subscribed.get() {
+ self.subscription.try_borrow_mut()?.unsubscribe().await?;
+ self.subscribed.set(false);
+ self.marketmap.clear();
+ self.latest_slot.store(0, Ordering::Relaxed);
+ }
+ Ok(())
+ }
+
+ pub fn values(&self) -> Vec {
+ self.marketmap.iter().map(|x| x.data.clone()).collect()
+ }
+
+ pub fn oracles(&self) -> Vec<(u16, Pubkey, OracleSource)> {
+ self.values().iter().map(|x| x.oracle_info()).collect()
+ }
+
+ pub fn size(&self) -> usize {
+ self.marketmap.len()
+ }
+
+ pub fn contains(&self, market_index: &u16) -> bool {
+ self.marketmap.contains_key(market_index)
+ }
+
+ pub fn get(&self, market_index: &u16) -> Option> {
+ self.marketmap
+ .get(market_index)
+ .map(|market| market.clone())
+ }
+
+ pub(crate) async fn sync(&self) -> SdkResult<()> {
+ if self.synced {
+ return Ok(());
+ }
+
+ let sync_lock = self.sync_lock.as_ref().expect("expected sync lock");
+
+ let lock = match sync_lock.try_lock() {
+ Ok(lock) => lock,
+ Err(_) => return Ok(()),
+ };
+
+ let options = self.subscription.try_borrow()?.options.clone();
+
+ let account_config = RpcAccountInfoConfig {
+ commitment: Some(self.commitment),
+ encoding: Some(options.encoding),
+ ..RpcAccountInfoConfig::default()
+ };
+
+ let gpa_config = RpcProgramAccountsConfig {
+ filters: Some(options.filters),
+ account_config,
+ with_context: Some(true),
+ };
+
+ let response = self
+ .rpc
+ .send::>>(
+ RpcRequest::GetProgramAccounts,
+ json!([drift::id().to_string(), gpa_config]),
+ )
+ .await?;
+
+ if let OptionalContext::Context(accounts) = response {
+ for account in accounts.value {
+ let slot = accounts.context.slot;
+ let market_data = account.account.data;
+ let data = decode::(market_data)?;
+ self.marketmap
+ .insert(data.market_index(), DataAndSlot { data, slot });
+ }
+
+ self.latest_slot
+ .store(accounts.context.slot, Ordering::Relaxed);
+ }
+
+ drop(lock);
+ Ok(())
+ }
+
+ pub fn get_latest_slot(&self) -> u64 {
+ self.latest_slot.load(Ordering::Relaxed)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::marketmap::MarketMap;
+ use drift::state::perp_market::PerpMarket;
+ use drift::state::spot_market::SpotMarket;
+ use solana_sdk::commitment_config::CommitmentConfig;
+ use solana_sdk::commitment_config::CommitmentLevel;
+
+ #[tokio::test]
+ #[cfg(rpc_tests)]
+ async fn test_marketmap_perp() {
+ let endpoint = "rpc".to_string();
+ let commitment = CommitmentConfig {
+ commitment: CommitmentLevel::Processed,
+ };
+
+ let marketmap = MarketMap::::new(commitment, endpoint, true);
+ marketmap.subscribe().await.unwrap();
+
+ tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
+
+ dbg!(marketmap.size());
+ assert!(marketmap.size() == 28);
+
+ dbg!(marketmap.get_latest_slot());
+
+ marketmap.unsubscribe().await.unwrap();
+
+ tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
+
+ assert_eq!(marketmap.size(), 0);
+ assert_eq!(marketmap.subscribed.get(), false);
+ }
+
+ #[tokio::test]
+ #[cfg(rpc_tests)]
+ async fn test_marketmap_spot() {
+ let endpoint = "rpc".to_string();
+ let commitment = CommitmentConfig {
+ commitment: CommitmentLevel::Processed,
+ };
+
+ let marketmap = MarketMap::::new(commitment, endpoint, true);
+ marketmap.subscribe().await.unwrap();
+
+ tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
+
+ dbg!(marketmap.size());
+ assert!(marketmap.size() == 13);
+
+ dbg!(marketmap.get_latest_slot());
+
+ marketmap.unsubscribe().await.unwrap();
+
+ tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
+
+ assert_eq!(marketmap.size(), 0);
+ assert_eq!(marketmap.subscribed.get(), false);
+ }
+}
diff --git a/src/memcmp.rs b/src/memcmp.rs
index 958f670..a89ed11 100644
--- a/src/memcmp.rs
+++ b/src/memcmp.rs
@@ -1,5 +1,7 @@
use anchor_lang::Discriminator;
-use drift::state::user::User;
+use drift::state::perp_market::PerpMarket;
+use drift::state::spot_market::SpotMarket;
+use drift::state::user::{MarketType, User};
use solana_client::rpc_filter::{Memcmp, RpcFilterType};
pub fn get_user_filter() -> RpcFilterType {
@@ -7,7 +9,7 @@ pub fn get_user_filter() -> RpcFilterType {
}
pub fn get_non_idle_user_filter() -> RpcFilterType {
- RpcFilterType::Memcmp(Memcmp::new_raw_bytes(4_350, vec![1]))
+ RpcFilterType::Memcmp(Memcmp::new_raw_bytes(4_350, vec![0]))
}
pub fn get_user_with_auction_filter() -> RpcFilterType {
@@ -17,3 +19,14 @@ pub fn get_user_with_auction_filter() -> RpcFilterType {
pub fn get_user_with_order_filter() -> RpcFilterType {
RpcFilterType::Memcmp(Memcmp::new_raw_bytes(4_352, vec![1]))
}
+
+pub fn get_market_filter(market_type: MarketType) -> RpcFilterType {
+ match market_type {
+ MarketType::Spot => {
+ RpcFilterType::Memcmp(Memcmp::new_raw_bytes(0, SpotMarket::discriminator().into()))
+ }
+ MarketType::Perp => {
+ RpcFilterType::Memcmp(Memcmp::new_raw_bytes(0, PerpMarket::discriminator().into()))
+ }
+ }
+}
diff --git a/src/oraclemap.rs b/src/oraclemap.rs
new file mode 100644
index 0000000..7d01f35
--- /dev/null
+++ b/src/oraclemap.rs
@@ -0,0 +1,465 @@
+use crate::utils::get_ws_url;
+use crate::websocket_account_subscriber::{AccountUpdate, WebsocketAccountSubscriber};
+use crate::{event_emitter::EventEmitter, SdkResult};
+use dashmap::DashMap;
+use drift::state::oracle::{get_oracle_price, OraclePriceData, OracleSource};
+use solana_account_decoder::{UiAccountData, UiAccountEncoding};
+use solana_client::nonblocking::rpc_client::RpcClient;
+use solana_client::rpc_config::RpcAccountInfoConfig;
+use solana_sdk::account_info::{AccountInfo, IntoAccountInfo};
+use solana_sdk::{commitment_config::CommitmentConfig, pubkey::Pubkey};
+use std::cell::{Cell, RefCell};
+use std::str::FromStr;
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::sync::{Arc, Mutex};
+
+#[derive(Copy, Clone, Debug)]
+pub struct OraclePriceDataAndSlot {
+ pub data: OraclePriceData,
+ pub slot: u64,
+}
+
+pub(crate) struct OracleMap {
+ subscribed: Cell,
+ pub(crate) oraclemap: Arc>,
+ event_emitter: &'static EventEmitter,
+ oracle_infos: DashMap,
+ sync_lock: Option>,
+ latest_slot: Arc,
+ commitment: CommitmentConfig,
+ rpc: RpcClient,
+ oracle_subscribers: RefCell>,
+ perp_oracles: DashMap,
+ spot_oracles: DashMap,
+}
+
+impl OracleMap {
+ pub fn new(
+ commitment: CommitmentConfig,
+ endpoint: String,
+ sync: bool,
+ perp_oracles: Vec<(u16, Pubkey, OracleSource)>,
+ spot_oracles: Vec<(u16, Pubkey, OracleSource)>,
+ ) -> Self {
+ let oraclemap = Arc::new(DashMap::new());
+
+ let event_emitter = EventEmitter::new();
+
+ let rpc = RpcClient::new_with_commitment(endpoint.clone(), commitment);
+
+ let sync_lock = if sync { Some(Mutex::new(())) } else { None };
+
+ let mut all_oracles = vec![];
+ all_oracles.extend(perp_oracles.clone());
+ all_oracles.extend(spot_oracles.clone());
+
+ let oracle_infos_map: DashMap<_, _> = all_oracles
+ .iter()
+ .map(|(_, pubkey, oracle_source)| (*pubkey, *oracle_source))
+ .collect();
+
+ let perp_oracles_map: DashMap<_, _> = perp_oracles
+ .iter()
+ .map(|(market_index, pubkey, _)| (*market_index, *pubkey))
+ .collect();
+
+ let spot_oracles_map: DashMap<_, _> = spot_oracles
+ .iter()
+ .map(|(market_index, pubkey, _)| (*market_index, *pubkey))
+ .collect();
+
+ Self {
+ subscribed: Cell::new(false),
+ oraclemap,
+ oracle_infos: oracle_infos_map,
+ sync_lock,
+ latest_slot: Arc::new(AtomicU64::new(0)),
+ commitment,
+ event_emitter: Box::leak(Box::new(event_emitter)),
+ rpc,
+ oracle_subscribers: RefCell::new(vec![]),
+ perp_oracles: perp_oracles_map,
+ spot_oracles: spot_oracles_map,
+ }
+ }
+
+ pub async fn subscribe(&self) -> SdkResult<()> {
+ if self.sync_lock.is_some() {
+ self.sync().await?;
+ }
+
+ if !self.subscribed.get() {
+ let url = get_ws_url(&self.rpc.url()).expect("valid url");
+ let subscription_name: &'static str = "oraclemap";
+
+ let mut oracle_subscribers = vec![];
+ for oracle_info in self.oracle_infos.iter() {
+ let oracle_pubkey = oracle_info.key();
+ let oracle_subscriber = WebsocketAccountSubscriber::new(
+ subscription_name,
+ url.clone(),
+ *oracle_pubkey,
+ self.commitment,
+ self.event_emitter.clone(),
+ );
+ oracle_subscribers.push(oracle_subscriber);
+ }
+
+ self.subscribed.set(true);
+
+ let oracle_source_by_oracle_key = self.oracle_infos.clone();
+ let oracle_map = self.oraclemap.clone();
+
+ self.event_emitter.subscribe("oraclemap", move |event| {
+ if let Some(update) = event.as_any().downcast_ref::() {
+ let oracle_pubkey = Pubkey::from_str(&update.pubkey).expect("valid pubkey");
+ let oracle_source_maybe = oracle_source_by_oracle_key.get(&oracle_pubkey);
+ if let Some(oracle_source) = oracle_source_maybe {
+ if let UiAccountData::Binary(blob, UiAccountEncoding::Base64) =
+ &update.data.data
+ {
+ let mut data = base64::decode(blob).expect("valid data");
+ let owner = Pubkey::from_str(&update.data.owner).expect("valid pubkey");
+ let mut lamports = update.data.lamports;
+ let oracle_account_info = AccountInfo::new(
+ &oracle_pubkey,
+ false,
+ false,
+ &mut lamports,
+ &mut data,
+ &owner,
+ false,
+ update.data.rent_epoch,
+ );
+ match get_oracle_price(
+ oracle_source.value(),
+ &oracle_account_info,
+ update.slot,
+ ) {
+ Ok(price_data) => {
+ oracle_map.insert(
+ update.pubkey.clone(),
+ OraclePriceDataAndSlot {
+ data: price_data,
+ slot: update.slot,
+ },
+ );
+ }
+ Err(err) => {
+ log::error!("Failed to get oracle price: {:?}", err)
+ }
+ }
+ }
+ }
+ }
+ });
+
+ let mut subscribers_clone = oracle_subscribers.clone();
+
+ let subscribe_futures = subscribers_clone
+ .iter_mut()
+ .map(|subscriber| subscriber.subscribe())
+ .collect::>();
+ let results = futures_util::future::join_all(subscribe_futures).await;
+ results.into_iter().collect::, _>>()?;
+
+ let mut oracle_subscribers_mut = self.oracle_subscribers.try_borrow_mut()?;
+ *oracle_subscribers_mut = oracle_subscribers;
+ }
+
+ Ok(())
+ }
+
+ pub async fn unsubscribe(&self) -> SdkResult<()> {
+ if self.subscribed.get() {
+ let mut oracle_subscribers = self.oracle_subscribers.try_borrow_mut()?;
+ let unsubscribe_futures = oracle_subscribers
+ .iter_mut()
+ .map(|subscriber| subscriber.unsubscribe())
+ .collect::>();
+
+ let results = futures_util::future::join_all(unsubscribe_futures).await;
+ results.into_iter().collect::, _>>()?;
+ self.subscribed.set(false);
+ self.oraclemap.clear();
+ self.latest_slot.store(0, Ordering::Relaxed);
+ }
+ Ok(())
+ }
+
+ async fn sync(&self) -> SdkResult<()> {
+ let sync_lock = self.sync_lock.as_ref().expect("expected sync lock");
+
+ let lock = match sync_lock.try_lock() {
+ Ok(lock) => lock,
+ Err(_) => return Ok(()),
+ };
+
+ let account_config = RpcAccountInfoConfig {
+ commitment: Some(self.commitment),
+ encoding: None,
+ ..RpcAccountInfoConfig::default()
+ };
+
+ let mut pubkeys = self
+ .oracle_infos
+ .iter()
+ .map(|oracle_info_ref| *oracle_info_ref.key())
+ .collect::>();
+ pubkeys.sort();
+
+ let mut oracle_infos = self
+ .oracle_infos
+ .iter()
+ .map(|oracle_info_ref| (*oracle_info_ref.key(), *oracle_info_ref.value()))
+ .collect::>();
+ oracle_infos.sort_by_key(|key| key.0);
+
+ let response = self
+ .rpc
+ .get_multiple_accounts_with_config(&pubkeys, account_config)
+ .await?;
+
+ if response.value.len() != pubkeys.len() {
+ return Err(crate::SdkError::Generic(format!(
+ "failed to get all oracle accounts, expected: {}, got: {}",
+ pubkeys.len(),
+ response.value.len()
+ )));
+ }
+
+ let slot = response.context.slot;
+
+ for (account, oracle_info) in response.value.iter().zip(oracle_infos.iter()) {
+ if let Some(oracle_account) = account {
+ let oracle_pubkey = oracle_info.0;
+ let mut oracle_components = (oracle_pubkey, oracle_account.clone());
+ let account_info = oracle_components.into_account_info();
+ let price_data = get_oracle_price(&oracle_info.1, &account_info, slot)
+ .map_err(|err| crate::SdkError::Anchor(Box::new(err.into())))?;
+ self.oraclemap.insert(
+ oracle_pubkey.to_string(),
+ OraclePriceDataAndSlot {
+ data: price_data,
+ slot,
+ },
+ );
+ }
+ }
+
+ self.latest_slot.store(slot, Ordering::Relaxed);
+
+ drop(lock);
+
+ Ok(())
+ }
+
+ pub fn size(&self) -> usize {
+ self.oraclemap.len()
+ }
+
+ pub fn contains(&self, key: &Pubkey) -> bool {
+ self.oracle_infos.contains_key(key)
+ }
+
+ pub fn current_perp_oracle(&self, market_index: u16) -> Option {
+ self.perp_oracles.get(&market_index).map(|x| *x)
+ }
+
+ pub fn current_spot_oracle(&self, market_index: u16) -> Option {
+ self.spot_oracles.get(&market_index).map(|x| *x)
+ }
+
+ pub fn get(&self, key: &str) -> Option {
+ self.oraclemap.get(key).map(|v| *v)
+ }
+
+ pub fn values(&self) -> Vec {
+ self.oraclemap.iter().map(|x| x.value().data).collect()
+ }
+
+ pub async fn add_oracle(&self, oracle: Pubkey, source: OracleSource) -> SdkResult<()> {
+ if self.contains(&oracle) {
+ return Ok(()); // don't add a duplicate
+ }
+
+ self.oracle_infos.insert(oracle, source);
+
+ let mut new_oracle_subscriber = WebsocketAccountSubscriber::new(
+ "oraclemap",
+ get_ws_url(&self.rpc.url()).expect("valid url"),
+ oracle,
+ self.commitment,
+ self.event_emitter.clone(),
+ );
+
+ new_oracle_subscriber.subscribe().await?;
+ let mut oracle_subscribers = self.oracle_subscribers.try_borrow_mut()?;
+ oracle_subscribers.push(new_oracle_subscriber);
+
+ Ok(())
+ }
+
+ pub fn update_spot_oracle(&self, market_index: u16, oracle: Pubkey) {
+ self.spot_oracles.insert(market_index, oracle);
+ }
+
+ pub fn update_perp_oracle(&self, market_index: u16, oracle: Pubkey) {
+ self.perp_oracles.insert(market_index, oracle);
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::marketmap::MarketMap;
+ use drift::state::perp_market::PerpMarket;
+ use drift::state::spot_market::SpotMarket;
+
+ #[tokio::test]
+ #[cfg(rpc_tests)]
+ async fn test_oracle_map() {
+ let commitment = CommitmentConfig::processed();
+ let endpoint = "rpc".to_string();
+
+ let spot_market_map =
+ MarketMap::::new(commitment.clone(), endpoint.clone(), true);
+ let perp_market_map =
+ MarketMap::::new(commitment.clone(), endpoint.clone(), true);
+
+ let _ = spot_market_map.sync().await;
+ let _ = perp_market_map.sync().await;
+
+ let perp_oracles = perp_market_map.oracles();
+ let spot_oracles = spot_market_map.oracles();
+
+ let mut oracles = vec![];
+ oracles.extend(perp_oracles.clone());
+ oracles.extend(spot_oracles.clone());
+
+ let mut oracle_infos = vec![];
+ for oracle_info in oracles {
+ if !oracle_infos.contains(&oracle_info) {
+ oracle_infos.push(oracle_info)
+ }
+ }
+
+ let oracle_infos_len = oracle_infos.len();
+ dbg!(oracle_infos_len);
+
+ let oracle_map = OracleMap::new(commitment, endpoint, true, perp_oracles, spot_oracles);
+
+ let _ = oracle_map.subscribe().await;
+
+ dbg!(oracle_map.size());
+ // assert_eq!(oracle_map.size(), oracle_infos_len);
+
+ dbg!("sleeping");
+ tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
+ dbg!("done sleeping");
+
+ dbg!("perp market oracles");
+ let mut last_sol_price = 0;
+ let mut last_sol_slot = 0;
+ let mut last_btc_price = 0;
+ let mut last_btc_slot = 0;
+ for _ in 0..10 {
+ tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
+ dbg!();
+ let sol_perp_market_oracle_pubkey = perp_market_map
+ .get(&0)
+ .expect("sol perp market")
+ .data
+ .amm
+ .oracle;
+ let sol_oracle = oracle_map
+ .get(&sol_perp_market_oracle_pubkey.to_string())
+ .expect("sol oracle");
+ dbg!("sol oracle info:");
+ dbg!(sol_oracle.data.price);
+ dbg!(sol_oracle.slot);
+ dbg!(
+ "sol price change: {}",
+ sol_oracle.data.price - last_sol_price
+ );
+ dbg!("sol slot change: {}", sol_oracle.slot - last_sol_slot);
+ last_sol_price = sol_oracle.data.price;
+ last_sol_slot = sol_oracle.slot;
+
+ dbg!();
+
+ let btc_perp_market_oracle_pubkey = perp_market_map
+ .get(&1)
+ .expect("btc perp market")
+ .data
+ .amm
+ .oracle;
+ let btc_oracle = oracle_map
+ .get(&btc_perp_market_oracle_pubkey.to_string())
+ .expect("btc oracle");
+ dbg!("btc oracle info:");
+ dbg!(btc_oracle.data.price);
+ dbg!(btc_oracle.slot);
+ dbg!(
+ "btc price change: {}",
+ btc_oracle.data.price - last_btc_price
+ );
+ dbg!("btc slot change: {}", btc_oracle.slot - last_btc_slot);
+ last_btc_price = btc_oracle.data.price;
+ last_btc_slot = btc_oracle.slot;
+ }
+
+ dbg!();
+
+ dbg!("spot market oracles");
+ let mut last_rndr_price = 0;
+ let mut last_rndr_slot = 0;
+ let mut last_weth_price = 0;
+ let mut last_weth_slot = 0;
+ for _ in 0..10 {
+ tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
+ dbg!();
+ let rndr_spot_market_oracle_pubkey = spot_market_map
+ .get(&11)
+ .expect("sol perp market")
+ .data
+ .oracle;
+ let rndr_oracle = oracle_map
+ .get(&rndr_spot_market_oracle_pubkey.to_string())
+ .expect("sol oracle");
+ dbg!("rndr oracle info:");
+ dbg!(rndr_oracle.data.price);
+ dbg!(rndr_oracle.slot);
+ dbg!(
+ "rndr price change: {}",
+ rndr_oracle.data.price - last_rndr_price
+ );
+ dbg!("rndr slot change: {}", rndr_oracle.slot - last_rndr_slot);
+ last_rndr_price = rndr_oracle.data.price;
+ last_rndr_slot = rndr_oracle.slot;
+
+ dbg!();
+
+ let weth_spot_market_oracle_pubkey = spot_market_map
+ .get(&4)
+ .expect("sol perp market")
+ .data
+ .oracle;
+ let weth_oracle = oracle_map
+ .get(&weth_spot_market_oracle_pubkey.to_string())
+ .expect("sol oracle");
+ dbg!("weth oracle info:");
+ dbg!(weth_oracle.data.price);
+ dbg!(weth_oracle.slot);
+ dbg!(
+ "weth price change: {}",
+ weth_oracle.data.price - last_weth_price
+ );
+ dbg!("weth slot change: {}", weth_oracle.slot - last_weth_slot);
+ last_weth_price = weth_oracle.data.price;
+ last_weth_slot = weth_oracle.slot;
+ }
+
+ let _ = oracle_map.unsubscribe().await;
+ }
+}
diff --git a/src/types.rs b/src/types.rs
index 45f5783..f410ea2 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -1,4 +1,7 @@
-use std::cmp::Ordering;
+use std::{
+ cell::{BorrowError, BorrowMutError},
+ cmp::Ordering,
+};
use anchor_lang::AccountDeserialize;
use drift::{error::ErrorCode, state::user::UserStats};
@@ -240,6 +243,14 @@ pub enum SdkError {
CouldntUnsubscribe(#[from] tokio::sync::mpsc::error::SendError<()>),
#[error("MathError")]
MathError(String),
+ #[error("{0}")]
+ BorrowMutError(#[from] BorrowMutError),
+ #[error("{0}")]
+ BorrowError(#[from] BorrowError),
+ #[error("{0}")]
+ Generic(String),
+ #[error("max connection attempts reached")]
+ MaxReconnectionAttemptsReached,
}
impl SdkError {
diff --git a/src/usermap.rs b/src/usermap.rs
index ab93985..1677408 100644
--- a/src/usermap.rs
+++ b/src/usermap.rs
@@ -16,12 +16,13 @@ use serde_json::json;
use solana_account_decoder::UiAccountEncoding;
use solana_client::nonblocking::rpc_client::RpcClient;
use solana_client::rpc_config::{RpcAccountInfoConfig, RpcProgramAccountsConfig};
+use solana_client::rpc_filter::RpcFilterType;
use solana_client::rpc_request::RpcRequest;
use solana_client::rpc_response::{OptionalContext, RpcKeyedAccount};
use solana_sdk::commitment_config::CommitmentConfig;
use solana_sdk::pubkey::Pubkey;
-pub struct Usermap {
+pub struct UserMap {
subscribed: bool,
subscription: WebsocketProgramAccountSubscriber,
pub(crate) usermap: Arc>,
@@ -31,9 +32,15 @@ pub struct Usermap {
rpc: RpcClient,
}
-impl Usermap {
- pub fn new(commitment: CommitmentConfig, endpoint: String, sync: bool) -> Self {
- let filters = vec![get_user_filter(), get_non_idle_user_filter()];
+impl UserMap {
+ pub fn new(
+ commitment: CommitmentConfig,
+ endpoint: String,
+ sync: bool,
+ additional_filters: Option>,
+ ) -> Self {
+ let mut filters = vec![get_user_filter(), get_non_idle_user_filter()];
+ filters.extend(additional_filters.unwrap_or_default());
let options = WebsocketProgramAccountOptions {
filters,
commitment,
@@ -64,7 +71,7 @@ impl Usermap {
}
pub async fn subscribe(&mut self) -> SdkResult<()> {
- if let Some(_) = self.sync_lock {
+ if self.sync_lock.is_some() {
self.sync().await?;
}
@@ -123,7 +130,7 @@ impl Usermap {
.get_account_data(&Pubkey::from_str(pubkey).unwrap())
.await?;
let user = User::try_deserialize(&mut user_data.as_slice()).unwrap();
- self.usermap.insert(pubkey.to_string(), user.clone());
+ self.usermap.insert(pubkey.to_string(), user);
Ok(self.get(pubkey).unwrap())
}
}
@@ -183,7 +190,7 @@ mod tests {
#[tokio::test]
#[cfg(rpc_tests)]
async fn test_usermap() {
- use crate::usermap::Usermap;
+ use crate::usermap::UserMap;
use solana_sdk::commitment_config::CommitmentConfig;
use solana_sdk::commitment_config::CommitmentLevel;
@@ -192,7 +199,7 @@ mod tests {
commitment: CommitmentLevel::Processed,
};
- let mut usermap = Usermap::new(commitment, endpoint, true);
+ let mut usermap = UserMap::new(commitment, endpoint, true);
usermap.subscribe().await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
diff --git a/src/websocket_account_subscriber.rs b/src/websocket_account_subscriber.rs
new file mode 100644
index 0000000..65ade3f
--- /dev/null
+++ b/src/websocket_account_subscriber.rs
@@ -0,0 +1,171 @@
+use futures_util::StreamExt;
+use solana_account_decoder::{UiAccount, UiAccountEncoding};
+use solana_client::{nonblocking::pubsub_client::PubsubClient, rpc_config::RpcAccountInfoConfig};
+use solana_sdk::{commitment_config::CommitmentConfig, pubkey::Pubkey};
+
+use crate::{
+ event_emitter::{Event, EventEmitter},
+ SdkResult,
+};
+
+#[derive(Clone, Debug)]
+pub(crate) struct AccountUpdate {
+ pub pubkey: String,
+ pub data: UiAccount,
+ pub slot: u64,
+}
+
+impl Event for AccountUpdate {
+ fn box_clone(&self) -> Box {
+ Box::new((*self).clone())
+ }
+
+ fn as_any(&self) -> &dyn std::any::Any {
+ self
+ }
+}
+
+#[derive(Clone)]
+pub struct WebsocketAccountSubscriber {
+ subscription_name: &'static str,
+ url: String,
+ pubkey: Pubkey,
+ pub(crate) commitment: CommitmentConfig,
+ pub subscribed: bool,
+ pub event_emitter: EventEmitter,
+ unsubscriber: Option>,
+}
+
+impl WebsocketAccountSubscriber {
+ pub fn new(
+ subscription_name: &'static str,
+ url: String,
+ pubkey: Pubkey,
+ commitment: CommitmentConfig,
+ event_emitter: EventEmitter,
+ ) -> Self {
+ WebsocketAccountSubscriber {
+ subscription_name,
+ url,
+ pubkey,
+ commitment,
+ subscribed: false,
+ event_emitter,
+ unsubscriber: None,
+ }
+ }
+
+ pub async fn subscribe(&mut self) -> SdkResult<()> {
+ if self.subscribed {
+ return Ok(());
+ }
+
+ self.subscribed = true;
+ self.subscribe_ws().await?;
+ Ok(())
+ }
+
+ async fn subscribe_ws(&mut self) -> SdkResult<()> {
+ let account_config = RpcAccountInfoConfig {
+ commitment: Some(self.commitment),
+ encoding: Some(UiAccountEncoding::Base64),
+ ..RpcAccountInfoConfig::default()
+ };
+ let (unsub_tx, mut unsub_rx) = tokio::sync::mpsc::channel::<()>(1);
+ self.unsubscriber = Some(unsub_tx);
+
+ let mut attempt = 0;
+ let max_reconnection_attempts = 20;
+ let base_delay = tokio::time::Duration::from_secs(2);
+
+ let url = self.url.clone();
+
+ tokio::spawn({
+ let event_emitter = self.event_emitter.clone();
+ let mut latest_slot = 0;
+ let subscription_name = self.subscription_name;
+ let pubkey = self.pubkey.clone();
+ async move {
+ loop {
+ let pubsub = PubsubClient::new(&url).await?;
+
+ match pubsub
+ .account_subscribe(&pubkey, Some(account_config.clone()))
+ .await
+ {
+ Ok((mut account_updates, account_unsubscribe)) => loop {
+ tokio::select! {
+ message = account_updates.next() => {
+ match message {
+ Some(message) => {
+ let slot = message.context.slot;
+ if slot >= latest_slot {
+ latest_slot = slot;
+ let account_update = AccountUpdate {
+ pubkey: pubkey.to_string(),
+ data: message.value,
+ slot,
+ };
+ event_emitter.emit(subscription_name, Box::new(account_update));
+ }
+ }
+ None => {
+ log::warn!("{}: Account stream interrupted", subscription_name);
+ account_unsubscribe().await;
+ break;
+ }
+ }
+ }
+ unsub = unsub_rx.recv() => {
+ if let Some(_) = unsub {
+ log::debug!("{}: Unsubscribing from account stream", subscription_name);
+ account_unsubscribe().await;
+ return Ok(());
+
+ }
+ }
+ }
+ },
+ Err(_) => {
+ log::error!(
+ "{}: Failed to subscribe to account stream, retrying",
+ subscription_name
+ );
+ attempt += 1;
+ if attempt >= max_reconnection_attempts {
+ log::error!("Max reconnection attempts reached.");
+ return Err(crate::SdkError::MaxReconnectionAttemptsReached);
+ }
+ }
+ }
+
+ if attempt >= max_reconnection_attempts {
+ log::error!("{}: Max reconnection attempts reached", subscription_name);
+ return Err(crate::SdkError::MaxReconnectionAttemptsReached);
+ }
+
+ let delay_duration = base_delay * 2_u32.pow(attempt);
+ log::debug!(
+ "{}: Reconnecting in {:?}",
+ subscription_name,
+ delay_duration
+ );
+ tokio::time::sleep(delay_duration).await;
+ attempt += 1;
+ }
+ }
+ });
+ Ok(())
+ }
+
+ pub async fn unsubscribe(&mut self) -> SdkResult<()> {
+ if self.subscribed && self.unsubscriber.is_some() {
+ if let Err(e) = self.unsubscriber.as_ref().unwrap().send(()).await {
+ log::error!("Failed to send unsubscribe signal: {:?}", e);
+ return Err(crate::SdkError::CouldntUnsubscribe(e));
+ }
+ self.subscribed = false;
+ }
+ Ok(())
+ }
+}
diff --git a/src/websocket_program_account_subscriber.rs b/src/websocket_program_account_subscriber.rs
index d8be2bb..5f706c2 100644
--- a/src/websocket_program_account_subscriber.rs
+++ b/src/websocket_program_account_subscriber.rs
@@ -42,6 +42,7 @@ impl Event for ProgramAccountUpd
}
}
+#[derive(Clone)]
pub struct WebsocketProgramAccountOptions {
pub filters: Vec,
pub commitment: CommitmentConfig,
@@ -102,53 +103,82 @@ impl WebsocketProgramAccountSubscriber {
..RpcProgramAccountsConfig::default()
};
- let pubsub = PubsubClient::new(&self.url).await?;
let (unsub_tx, mut unsub_rx) = tokio::sync::mpsc::channel::<()>(1);
self.unsubscriber = Some(unsub_tx);
+ let mut attempt = 0;
+ let max_reconnection_attempts = 20;
+ let base_delay = tokio::time::Duration::from_secs(5);
+
+ let url = self.url.clone();
tokio::spawn({
let event_emitter = self.event_emitter.clone();
let mut latest_slot = 0;
let subscription_name = self.subscription_name;
async move {
- let (mut accounts, unsubscriber) = pubsub
- .program_subscribe(&drift::ID, Some(config))
- .await
- .unwrap();
loop {
- tokio::select! {
- message = accounts.next() => {
- match message {
- Some(message) => {
- let slot = message.context.slot;
- if slot >= latest_slot {
- latest_slot = slot;
- let pubkey = message.value.pubkey;
- let account_data = message.value.account.data;
- match decode(account_data) {
- Ok(data) => {
- let data_and_slot = DataAndSlot:: { slot, data };
- event_emitter.emit(subscription_name, Box::new(ProgramAccountUpdate::new(pubkey, data_and_slot)));
- },
- Err(e) => {
- error!("Error decoding account data {e}");
+ let pubsub = PubsubClient::new(&url).await?;
+ match pubsub
+ .program_subscribe(&drift::ID, Some(config.clone()))
+ .await
+ {
+ Ok((mut accounts, unsubscriber)) => loop {
+ tokio::select! {
+ message = accounts.next() => {
+ match message {
+ Some(message) => {
+ let slot = message.context.slot;
+ if slot >= latest_slot {
+ latest_slot = slot;
+ let pubkey = message.value.pubkey;
+ let account_data = message.value.account.data;
+ match decode(account_data) {
+ Ok(data) => {
+ let data_and_slot = DataAndSlot:: { slot, data };
+ event_emitter.emit(subscription_name, Box::new(ProgramAccountUpdate::new(pubkey, data_and_slot)));
+ },
+ Err(e) => {
+ error!("Error decoding account data {e}");
+ }
+ }
}
}
+ None => {
+ warn!("{} stream ended", subscription_name);
+ unsubscriber().await;
+ break;
+ }
}
}
- None => {
- warn!("{} stream ended", subscription_name);
+ _ = unsub_rx.recv() => {
+ debug!("Unsubscribing.");
unsubscriber().await;
- break;
+ return Ok(());
}
}
- }
- _ = unsub_rx.recv() => {
- debug!("Unsubscribing.");
- unsubscriber().await;
- break;
+ },
+ Err(_) => {
+ error!("Failed to subscribe to program stream, retrying.");
+ attempt += 1;
+ if attempt >= max_reconnection_attempts {
+ error!("Max reconnection attempts reached.");
+ return Err(SdkError::MaxReconnectionAttemptsReached);
+ }
}
}
+
+ if attempt >= max_reconnection_attempts {
+ error!("Max reconnection attempts reached.");
+ return Err(SdkError::MaxReconnectionAttemptsReached);
+ }
+
+ let delay_duration = base_delay * 2_u32.pow(attempt);
+ debug!(
+ "{}: Reconnecting in {:?}",
+ subscription_name, delay_duration
+ );
+ tokio::time::sleep(delay_duration).await;
+ attempt += 1;
}
}
});