diff --git a/Cargo.lock b/Cargo.lock index 1e7b482..19ea8e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -926,6 +926,19 @@ dependencies = [ "zeroize", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.3", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.5.0" @@ -1103,6 +1116,7 @@ dependencies = [ "borsh 1.3.1", "bytemuck", "bytes", + "dashmap", "drift", "env_logger 0.10.2", "fnv", diff --git a/Cargo.toml b/Cargo.toml index 7ea2d5f..fb77eca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ thiserror = "1.0.38" tokio = { version = "1.34.0", features = ["full"] } tokio-tungstenite = { version = "0.21.0", features = ["native-tls"] } regex = "1.10.2" +dashmap = "5.5.3" [dev-dependencies] pyth = { git = "https://github.com/drift-labs/protocol-v2.git", tag = "v2.67.0", features = [ diff --git a/src/lib.rs b/src/lib.rs index da92086..88afa48 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,6 +71,8 @@ pub mod dlob; pub mod event_subscriber; pub mod slot_subscriber; +pub mod usermap; + use types::*; /// Provides solana Account fetching API diff --git a/src/usermap.rs b/src/usermap.rs new file mode 100644 index 0000000..c88c697 --- /dev/null +++ b/src/usermap.rs @@ -0,0 +1,212 @@ +use std::str::FromStr; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; + +use crate::event_emitter::EventEmitter; +use crate::memcmp::{get_non_idle_user_filter, get_user_filter}; +use crate::utils::{decode, get_ws_url}; +use crate::websocket_program_account_subscriber::{ + ProgramAccountUpdate, WebsocketProgramAccountOptions, WebsocketProgramAccountSubscriber, +}; +use crate::SdkResult; +use anchor_lang::AccountDeserialize; +use dashmap::DashMap; +use drift::state::user::User; +use serde_json::json; +use solana_account_decoder::UiAccountEncoding; +use solana_client::nonblocking::rpc_client::RpcClient; +use solana_client::rpc_config::{RpcAccountInfoConfig, RpcProgramAccountsConfig}; +use solana_client::rpc_request::RpcRequest; +use solana_client::rpc_response::{OptionalContext, RpcKeyedAccount}; +use solana_sdk::commitment_config::CommitmentConfig; +use solana_sdk::pubkey::Pubkey; + +pub struct Usermap { + subscribed: bool, + subscription: WebsocketProgramAccountSubscriber, + usermap: Arc>, + sync_lock: Option>, + latest_slot: Arc, + commitment: CommitmentConfig, + rpc: RpcClient, +} + +impl Usermap { + pub fn new(commitment: CommitmentConfig, endpoint: String, sync: bool) -> Self { + let filters = vec![get_user_filter(), get_non_idle_user_filter()]; + let options = WebsocketProgramAccountOptions { + filters, + commitment, + encoding: UiAccountEncoding::Base64, + }; + let event_emitter = EventEmitter::new(); + + let url = get_ws_url(&endpoint.clone()).unwrap(); + + let subscription = + WebsocketProgramAccountSubscriber::new("usermap", url, options, event_emitter); + + let usermap = Arc::new(DashMap::new()); + + let rpc = RpcClient::new_with_commitment(endpoint.clone(), commitment); + + let sync_lock = if sync { Some(Mutex::new(())) } else { None }; + + Self { + subscribed: false, + subscription, + usermap, + sync_lock, + latest_slot: Arc::new(AtomicU64::new(0)), + commitment, + rpc, + } + } + + pub async fn subscribe(&mut self) -> SdkResult<()> { + if let Some(_) = self.sync_lock { + self.sync().await?; + } + + if !self.subscribed { + self.subscription.subscribe::().await?; + self.subscribed = true; + } + + let usermap = self.usermap.clone(); + let latest_slot = self.latest_slot.clone(); + + self.subscription + .event_emitter + .subscribe("usermap", move |event| { + if let Some(update) = event.as_any().downcast_ref::>() { + let user_data_and_slot = update.data_and_slot.clone(); + let user_pubkey = update.pubkey.to_string(); + if update.data_and_slot.slot > latest_slot.load(Ordering::Relaxed) { + latest_slot.store(update.data_and_slot.slot, Ordering::Relaxed); + } + usermap.insert(user_pubkey, user_data_and_slot.data); + } + }); + + Ok(()) + } + + pub async fn unsubscribe(&mut self) -> SdkResult<()> { + if self.subscribed { + self.subscription.unsubscribe().await?; + self.subscribed = false; + self.usermap.clear(); + self.latest_slot.store(0, Ordering::Relaxed); + } + Ok(()) + } + + pub fn size(&self) -> usize { + self.usermap.len() + } + + pub fn contains(&self, pubkey: &str) -> bool { + self.usermap.contains_key(pubkey) + } + + pub fn get(&self, pubkey: &str) -> Option { + self.usermap.get(pubkey).map(|user| user.value().clone()) + } + + pub async fn must_get(&self, pubkey: &str) -> SdkResult { + if let Some(user) = self.get(pubkey) { + Ok(user) + } else { + let user_data = self + .rpc + .get_account_data(&Pubkey::from_str(pubkey).unwrap()) + .await?; + let user = User::try_deserialize(&mut user_data.as_slice()).unwrap(); + self.usermap.insert(pubkey.to_string(), user.clone()); + Ok(self.get(pubkey).unwrap()) + } + } + + async fn sync(&mut self) -> SdkResult<()> { + let sync_lock = self.sync_lock.as_ref().expect("expected sync lock"); + + let lock = match sync_lock.try_lock() { + Ok(lock) => lock, + Err(_) => return Ok(()), + }; + + let account_config = RpcAccountInfoConfig { + commitment: Some(self.commitment), + encoding: Some(self.subscription.options.encoding), + ..RpcAccountInfoConfig::default() + }; + + let gpa_config = RpcProgramAccountsConfig { + filters: Some(self.subscription.options.filters.clone()), + account_config, + with_context: Some(true), + }; + + let response = self + .rpc + .send::>>( + RpcRequest::GetProgramAccounts, + json!([drift::id().to_string(), gpa_config]), + ) + .await?; + + if let OptionalContext::Context(accounts) = response { + for account in accounts.value { + let pubkey = account.pubkey; + let user_data = account.account.data; + let data = decode::(user_data)?; + self.usermap.insert(pubkey, data); + } + + self.latest_slot + .store(accounts.context.slot, Ordering::Relaxed); + } + + drop(lock); + Ok(()) + } + + pub fn get_latest_slot(&self) -> u64 { + self.latest_slot.load(Ordering::Relaxed) + } +} + +#[cfg(test)] +mod tests { + + #[tokio::test] + #[cfg(rpc_tests)] + async fn test_usermap() { + use crate::usermap::Usermap; + use solana_sdk::commitment_config::CommitmentConfig; + use solana_sdk::commitment_config::CommitmentLevel; + + let endpoint = "rpc_url".to_string(); + let commitment = CommitmentConfig { + commitment: CommitmentLevel::Processed, + }; + + let mut usermap = Usermap::new(commitment, endpoint, true); + usermap.subscribe().await.unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_secs(30)).await; + + dbg!(usermap.size()); + assert!(usermap.size() > 50000); + + dbg!(usermap.get_latest_slot()); + + usermap.unsubscribe().await.unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_secs(10)).await; + + assert_eq!(usermap.size(), 0); + assert_eq!(usermap.subscribed, false); + } +} diff --git a/src/utils.rs b/src/utils.rs index cb1cfd9..2b8c82e 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -85,6 +85,18 @@ pub fn http_to_ws(url: &str) -> Result { Ok(format!("{}/ws", base_url.trim_end_matches('/'))) } +pub fn get_ws_url(url: &str) -> Result { + let base_url = if url.starts_with("http://") { + url.replacen("http://", "ws://", 1) + } else if url.starts_with("https://") { + url.replacen("https://", "wss://", 1) + } else { + return Err("Invalid URL scheme"); + }; + + Ok(base_url) +} + pub fn dlob_subscribe_ws_json(market: &str) -> String { json!({ "type": "subscribe", diff --git a/src/websocket_program_account_subscriber.rs b/src/websocket_program_account_subscriber.rs index 42e1186..d8be2bb 100644 --- a/src/websocket_program_account_subscriber.rs +++ b/src/websocket_program_account_subscriber.rs @@ -51,7 +51,7 @@ pub struct WebsocketProgramAccountOptions { pub struct WebsocketProgramAccountSubscriber { subscription_name: &'static str, url: String, - options: WebsocketProgramAccountOptions, + pub(crate) options: WebsocketProgramAccountOptions, pub subscribed: bool, pub event_emitter: EventEmitter, unsubscriber: Option>,