diff --git a/Cargo.lock b/Cargo.lock index b56a492c96..a21b5f1c97 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4423,6 +4423,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linked-hash-map" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" + [[package]] name = "linux-raw-sys" version = "0.4.15" @@ -5647,10 +5653,11 @@ dependencies = [ [[package]] name = "pyth-lazer-client" -version = "0.1.3" +version = "1.0.0" dependencies = [ "alloy-primitives 0.8.25", "anyhow", + "backoff", "base64 0.22.1", "bincode 1.3.3", "bs58", @@ -5665,6 +5672,8 @@ dependencies = [ "tokio", "tokio-tungstenite 0.20.1", "tracing", + "tracing-subscriber", + "ttl_cache", "url", ] @@ -10295,6 +10304,15 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "ttl_cache" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4189890526f0168710b6ee65ceaedf1460c48a14318ceec933cb26baa492096a" +dependencies = [ + "linked-hash-map", +] + [[package]] name = "tungstenite" version = "0.20.1" diff --git a/lazer/sdk/rust/client/Cargo.toml b/lazer/sdk/rust/client/Cargo.toml index 3801762059..2f30cda4ae 100644 --- a/lazer/sdk/rust/client/Cargo.toml +++ b/lazer/sdk/rust/client/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pyth-lazer-client" -version = "0.1.3" +version = "1.0.0" edition = "2021" description = "A Rust client for Pyth Lazer" license = "Apache-2.0" @@ -17,6 +17,9 @@ anyhow = "1.0" tracing = "0.1" url = "2.4" derive_more = { version = "1.0.0", features = ["from"] } +backoff = { version = "0.4.0", features = ["futures", "tokio"] } +ttl_cache = "0.5.1" + [dev-dependencies] bincode = "1.3.3" @@ -25,3 +28,4 @@ hex = "0.4.3" libsecp256k1 = "0.7.1" bs58 = "0.5.1" alloy-primitives = "0.8.19" +tracing-subscriber = { version = "0.3.19", features = ["env-filter", "json"] } diff --git a/lazer/sdk/rust/client/examples/subscribe_price_feeds.rs b/lazer/sdk/rust/client/examples/subscribe_price_feeds.rs index 30efd2a8b8..e39c4dc41b 100644 --- a/lazer/sdk/rust/client/examples/subscribe_price_feeds.rs +++ b/lazer/sdk/rust/client/examples/subscribe_price_feeds.rs @@ -1,6 +1,9 @@ +use std::time::Duration; + use base64::Engine; -use futures_util::StreamExt; -use pyth_lazer_client::{AnyResponse, LazerClient}; +use pyth_lazer_client::backoff::PythLazerExponentialBackoffBuilder; +use pyth_lazer_client::client::PythLazerClientBuilder; +use pyth_lazer_client::ws_connection::AnyResponse; use pyth_lazer_protocol::message::{ EvmMessage, LeEcdsaMessage, LeUnsignedMessage, Message, SolanaMessage, }; @@ -9,8 +12,10 @@ use pyth_lazer_protocol::router::{ Channel, DeliveryFormat, FixedRate, Format, JsonBinaryEncoding, PriceFeedId, PriceFeedProperty, SubscriptionParams, SubscriptionParamsRepr, }; -use pyth_lazer_protocol::subscription::{Request, Response, SubscribeRequest, SubscriptionId}; +use pyth_lazer_protocol::subscription::{Response, SubscribeRequest, SubscriptionId}; use tokio::pin; +use tracing::level_filters::LevelFilter; +use tracing_subscriber::EnvFilter; fn get_lazer_access_token() -> String { // Place your access token in your env at LAZER_ACCESS_TOKEN or set it here @@ -20,11 +25,32 @@ fn get_lazer_access_token() -> String { #[tokio::main] async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env()?, + ) + .json() + .init(); + // Create and start the client - let mut client = LazerClient::new( - "wss://pyth-lazer.dourolabs.app/v1/stream", - &get_lazer_access_token(), - )?; + let mut client = PythLazerClientBuilder::new(get_lazer_access_token()) + // Optionally override the default endpoints + .with_endpoints(vec![ + "wss://pyth-lazer-0.dourolabs.app/v1/stream".parse()?, + "wss://pyth-lazer-1.dourolabs.app/v1/stream".parse()?, + ]) + // Optionally set the number of connections + .with_num_connections(4) + // Optionally set the backoff strategy + .with_backoff(PythLazerExponentialBackoffBuilder::default().build()) + // Optionally set the timeout for each connection + .with_timeout(Duration::from_secs(5)) + // Optionally set the channel capacity for responses + .with_channel_capacity(1000) + .build()?; + let stream = client.start().await?; pin!(stream); @@ -72,16 +98,16 @@ async fn main() -> anyhow::Result<()> { ]; for req in subscription_requests { - client.subscribe(Request::Subscribe(req)).await?; + client.subscribe(req).await?; } println!("Subscribed to price feeds. Waiting for updates..."); // Process the first few updates let mut count = 0; - while let Some(msg) = stream.next().await { + while let Some(msg) = stream.recv().await { // The stream gives us base64-encoded binary messages. We need to decode, parse, and verify them. - match msg? { + match msg { AnyResponse::Json(msg) => match msg { Response::StreamUpdated(update) => { println!("Received a JSON update for {:?}", update.subscription_id); @@ -189,8 +215,6 @@ async fn main() -> anyhow::Result<()> { println!("Unsubscribed from {sub_id:?}"); } - tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; - client.close().await?; Ok(()) } diff --git a/lazer/sdk/rust/client/src/backoff.rs b/lazer/sdk/rust/client/src/backoff.rs new file mode 100644 index 0000000000..263a09b57f --- /dev/null +++ b/lazer/sdk/rust/client/src/backoff.rs @@ -0,0 +1,69 @@ +use std::time::Duration; + +use backoff::{ + default::{INITIAL_INTERVAL_MILLIS, MAX_INTERVAL_MILLIS, MULTIPLIER, RANDOMIZATION_FACTOR}, + ExponentialBackoff, ExponentialBackoffBuilder, +}; + +#[derive(Debug)] +pub struct PythLazerExponentialBackoffBuilder { + initial_interval: Duration, + randomization_factor: f64, + multiplier: f64, + max_interval: Duration, +} + +impl Default for PythLazerExponentialBackoffBuilder { + fn default() -> Self { + Self { + initial_interval: Duration::from_millis(INITIAL_INTERVAL_MILLIS), + randomization_factor: RANDOMIZATION_FACTOR, + multiplier: MULTIPLIER, + max_interval: Duration::from_millis(MAX_INTERVAL_MILLIS), + } + } +} + +impl PythLazerExponentialBackoffBuilder { + pub fn new() -> Self { + Default::default() + } + + /// The initial retry interval. + pub fn with_initial_interval(&mut self, initial_interval: Duration) -> &mut Self { + self.initial_interval = initial_interval; + self + } + + /// The randomization factor to use for creating a range around the retry interval. + /// + /// A randomization factor of 0.5 results in a random period ranging between 50% below and 50% + /// above the retry interval. + pub fn with_randomization_factor(&mut self, randomization_factor: f64) -> &mut Self { + self.randomization_factor = randomization_factor; + self + } + + /// The value to multiply the current interval with for each retry attempt. + pub fn with_multiplier(&mut self, multiplier: f64) -> &mut Self { + self.multiplier = multiplier; + self + } + + /// The maximum value of the back off period. Once the retry interval reaches this + /// value it stops increasing. + pub fn with_max_interval(&mut self, max_interval: Duration) -> &mut Self { + self.max_interval = max_interval; + self + } + + pub fn build(&self) -> ExponentialBackoff { + ExponentialBackoffBuilder::default() + .with_initial_interval(self.initial_interval) + .with_randomization_factor(self.randomization_factor) + .with_multiplier(self.multiplier) + .with_max_interval(self.max_interval) + .with_max_elapsed_time(None) + .build() + } +} diff --git a/lazer/sdk/rust/client/src/client.rs b/lazer/sdk/rust/client/src/client.rs new file mode 100644 index 0000000000..3b3e38bf48 --- /dev/null +++ b/lazer/sdk/rust/client/src/client.rs @@ -0,0 +1,186 @@ +use std::time::Duration; + +use crate::{ + resilient_ws_connection::PythLazerResilientWSConnection, ws_connection::AnyResponse, + CHANNEL_CAPACITY, +}; +use anyhow::{bail, Result}; +use backoff::ExponentialBackoff; +use pyth_lazer_protocol::subscription::{SubscribeRequest, SubscriptionId}; +use tokio::sync::mpsc::{self, error::TrySendError}; +use tracing::{error, warn}; +use ttl_cache::TtlCache; +use url::Url; + +const DEDUP_CACHE_SIZE: usize = 100_000; +const DEDUP_TTL: Duration = Duration::from_secs(10); + +const DEFAULT_ENDPOINTS: [&str; 2] = [ + "wss://pyth-lazer-0.dourolabs.app/v1/stream", + "wss://pyth-lazer-1.dourolabs.app/v1/stream", +]; +const DEFAULT_NUM_CONNECTIONS: usize = 4; +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5); + +pub struct PythLazerClient { + endpoints: Vec, + access_token: String, + num_connections: usize, + ws_connections: Vec, + backoff: ExponentialBackoff, + timeout: Duration, + channel_capacity: usize, +} + +impl PythLazerClient { + /// Creates a new client instance + /// + /// # Arguments + /// * `endpoints` - A vector of endpoint URLs + /// * `access_token` - The access token for authentication + /// * `num_connections` - The number of WebSocket connections to maintain + pub fn new( + endpoints: Vec, + access_token: String, + num_connections: usize, + backoff: ExponentialBackoff, + timeout: Duration, + channel_capacity: usize, + ) -> Result { + if backoff.max_elapsed_time.is_some() { + bail!("max_elapsed_time is not supported in Pyth Lazer client"); + } + if endpoints.is_empty() { + bail!("At least one endpoint must be provided"); + } + Ok(Self { + endpoints, + access_token, + num_connections, + ws_connections: Vec::with_capacity(num_connections), + backoff, + timeout, + channel_capacity, + }) + } + + pub async fn start(&mut self) -> Result> { + let (sender, receiver) = mpsc::channel::(self.channel_capacity); + let (ws_connection_sender, mut ws_connection_receiver) = + mpsc::channel::(CHANNEL_CAPACITY); + + for i in 0..self.num_connections { + let endpoint = self.endpoints[i % self.endpoints.len()].clone(); + let connection = PythLazerResilientWSConnection::new( + endpoint, + self.access_token.clone(), + self.backoff.clone(), + self.timeout, + ws_connection_sender.clone(), + ); + self.ws_connections.push(connection); + } + + let mut seen_updates = TtlCache::new(DEDUP_CACHE_SIZE); + + tokio::spawn(async move { + while let Some(response) = ws_connection_receiver.recv().await { + let cache_key = response.cache_key(); + if seen_updates.contains_key(&cache_key) { + continue; + } + seen_updates.insert(cache_key, response.clone(), DEDUP_TTL); + + match sender.try_send(response) { + Ok(_) => (), + Err(TrySendError::Full(r)) => { + warn!("Sender channel is full, responses will be delayed"); + if sender.send(r).await.is_err() { + error!("Sender channel is closed, stopping client"); + } + } + Err(TrySendError::Closed(_)) => { + error!("Sender channel is closed, stopping client"); + } + } + } + }); + + Ok(receiver) + } + + pub async fn subscribe(&mut self, subscribe_request: SubscribeRequest) -> Result<()> { + for connection in &mut self.ws_connections { + connection.subscribe(subscribe_request.clone()).await?; + } + Ok(()) + } + + pub async fn unsubscribe(&mut self, subscription_id: SubscriptionId) -> Result<()> { + for connection in &mut self.ws_connections { + connection.unsubscribe(subscription_id).await?; + } + Ok(()) + } +} + +pub struct PythLazerClientBuilder { + endpoints: Vec, + access_token: String, + num_connections: usize, + backoff: ExponentialBackoff, + timeout: Duration, + channel_capacity: usize, +} + +impl PythLazerClientBuilder { + pub fn new(access_token: String) -> Self { + Self { + endpoints: DEFAULT_ENDPOINTS + .iter() + .map(|&s| s.parse().unwrap()) + .collect(), + access_token, + num_connections: DEFAULT_NUM_CONNECTIONS, + backoff: ExponentialBackoff::default(), + timeout: DEFAULT_TIMEOUT, + channel_capacity: CHANNEL_CAPACITY, + } + } + + pub fn with_endpoints(mut self, endpoints: Vec) -> Self { + self.endpoints = endpoints; + self + } + + pub fn with_num_connections(mut self, num_connections: usize) -> Self { + self.num_connections = num_connections; + self + } + + pub fn with_backoff(mut self, backoff: ExponentialBackoff) -> Self { + self.backoff = backoff; + self + } + + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + + pub fn with_channel_capacity(mut self, channel_capacity: usize) -> Self { + self.channel_capacity = channel_capacity; + self + } + + pub fn build(self) -> Result { + PythLazerClient::new( + self.endpoints, + self.access_token, + self.num_connections, + self.backoff, + self.timeout, + self.channel_capacity, + ) + } +} diff --git a/lazer/sdk/rust/client/src/lib.rs b/lazer/sdk/rust/client/src/lib.rs index 30c1df8902..c62eab1ff4 100644 --- a/lazer/sdk/rust/client/src/lib.rs +++ b/lazer/sdk/rust/client/src/lib.rs @@ -1,138 +1,6 @@ -use anyhow::Result; -use derive_more::From; -use futures_util::{SinkExt, StreamExt, TryStreamExt}; -use pyth_lazer_protocol::{ - binary_update::BinaryWsUpdate, - subscription::{ErrorResponse, Request, Response, SubscriptionId, UnsubscribeRequest}, -}; -use tokio_tungstenite::{connect_async, tungstenite::Message}; -use url::Url; +const CHANNEL_CAPACITY: usize = 1000; -/// A WebSocket client for consuming Pyth Lazer price feed updates -/// -/// This client provides a simple interface to: -/// - Connect to a Lazer WebSocket endpoint -/// - Subscribe to price feed updates -/// - Receive updates as a stream of messages -/// -pub struct LazerClient { - endpoint: Url, - access_token: String, - ws_sender: Option< - futures_util::stream::SplitSink< - tokio_tungstenite::WebSocketStream< - tokio_tungstenite::MaybeTlsStream, - >, - Message, - >, - >, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash, From)] -pub enum AnyResponse { - Json(Response), - Binary(BinaryWsUpdate), -} - -impl LazerClient { - /// Creates a new Lazer client instance - /// - /// # Arguments - /// * `endpoint` - The WebSocket URL of the Lazer service - /// * `access_token` - Access token for authentication - /// - /// # Returns - /// Returns a new client instance (not yet connected) - pub fn new(endpoint: &str, access_token: &str) -> Result { - let endpoint = Url::parse(endpoint)?; - let access_token = access_token.to_string(); - Ok(Self { - endpoint, - access_token, - ws_sender: None, - }) - } - - /// Starts the WebSocket connection - /// - /// # Returns - /// Returns a stream of responses from the server - pub async fn start(&mut self) -> Result>> { - let url = self.endpoint.clone(); - let mut request = - tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(url)?; - - request.headers_mut().insert( - "Authorization", - format!("Bearer {}", self.access_token).parse().unwrap(), - ); - - let (ws_stream, _) = connect_async(request).await?; - let (ws_sender, ws_receiver) = ws_stream.split(); - - self.ws_sender = Some(ws_sender); - let response_stream = - ws_receiver - .map_err(anyhow::Error::from) - .try_filter_map(|msg| async { - let r: Result> = match msg { - Message::Text(text) => { - Ok(Some(serde_json::from_str::(&text)?.into())) - } - Message::Binary(data) => { - Ok(Some(BinaryWsUpdate::deserialize_slice(&data)?.into())) - } - Message::Close(_) => Ok(Some( - Response::Error(ErrorResponse { - error: "WebSocket connection closed".to_string(), - }) - .into(), - )), - _ => Ok(None), - }; - r - }); - - Ok(response_stream) - } - - /// Subscribes to price feed updates - /// - /// # Arguments - /// * `request` - A subscription request containing feed IDs and parameters - pub async fn subscribe(&mut self, request: Request) -> Result<()> { - if let Some(sender) = &mut self.ws_sender { - let msg = serde_json::to_string(&request)?; - sender.send(Message::Text(msg)).await?; - Ok(()) - } else { - anyhow::bail!("WebSocket connection not started") - } - } - - /// Unsubscribes from a previously subscribed feed - /// - /// # Arguments - /// * `subscription_id` - The ID of the subscription to cancel - pub async fn unsubscribe(&mut self, subscription_id: SubscriptionId) -> Result<()> { - if let Some(sender) = &mut self.ws_sender { - let request = Request::Unsubscribe(UnsubscribeRequest { subscription_id }); - let msg = serde_json::to_string(&request)?; - sender.send(Message::Text(msg)).await?; - Ok(()) - } else { - anyhow::bail!("WebSocket connection not started") - } - } - - /// Closes the WebSocket connection - pub async fn close(&mut self) -> Result<()> { - if let Some(sender) = &mut self.ws_sender { - sender.send(Message::Close(None)).await?; - self.ws_sender = None; - Ok(()) - } else { - anyhow::bail!("WebSocket connection not started") - } - } -} +pub mod backoff; +pub mod client; +pub mod resilient_ws_connection; +pub mod ws_connection; diff --git a/lazer/sdk/rust/client/src/resilient_ws_connection.rs b/lazer/sdk/rust/client/src/resilient_ws_connection.rs new file mode 100644 index 0000000000..70385d5946 --- /dev/null +++ b/lazer/sdk/rust/client/src/resilient_ws_connection.rs @@ -0,0 +1,210 @@ +use std::time::Duration; + +use backoff::{backoff::Backoff, ExponentialBackoff}; +use futures_util::StreamExt; +use pyth_lazer_protocol::subscription::{ + Request, SubscribeRequest, SubscriptionId, UnsubscribeRequest, +}; +use tokio::{pin, select, sync::mpsc, time::Instant}; +use tracing::{error, info, warn}; +use url::Url; + +use crate::{ + ws_connection::{AnyResponse, PythLazerWSConnection}, + CHANNEL_CAPACITY, +}; +use anyhow::{bail, Context, Result}; + +const BACKOFF_RESET_DURATION: Duration = Duration::from_secs(10); + +pub struct PythLazerResilientWSConnection { + request_sender: mpsc::Sender, +} + +impl PythLazerResilientWSConnection { + /// Creates a new resilient WebSocket client instance + /// + /// # Arguments + /// * `endpoint` - The WebSocket URL of the Lazer service + /// * `access_token` - Access token for authentication + /// * `sender` - A sender to send responses back to the client + /// + /// # Returns + /// Returns a new client instance (not yet connected) + pub fn new( + endpoint: Url, + access_token: String, + backoff: ExponentialBackoff, + timeout: Duration, + sender: mpsc::Sender, + ) -> Self { + let (request_sender, mut request_receiver) = mpsc::channel(CHANNEL_CAPACITY); + let mut task = + PythLazerResilientWSConnectionTask::new(endpoint, access_token, backoff, timeout); + + tokio::spawn(async move { + if let Err(e) = task.run(sender, &mut request_receiver).await { + error!("Resilient WebSocket connection task failed: {}", e); + } + }); + + Self { request_sender } + } + + pub async fn subscribe(&mut self, request: SubscribeRequest) -> Result<()> { + self.request_sender + .send(Request::Subscribe(request)) + .await + .context("Failed to send subscribe request")?; + Ok(()) + } + + pub async fn unsubscribe(&mut self, subscription_id: SubscriptionId) -> Result<()> { + self.request_sender + .send(Request::Unsubscribe(UnsubscribeRequest { subscription_id })) + .await + .context("Failed to send unsubscribe request")?; + Ok(()) + } +} + +struct PythLazerResilientWSConnectionTask { + endpoint: Url, + access_token: String, + subscriptions: Vec, + backoff: ExponentialBackoff, + timeout: Duration, +} + +impl PythLazerResilientWSConnectionTask { + pub fn new( + endpoint: Url, + access_token: String, + backoff: ExponentialBackoff, + timeout: Duration, + ) -> Self { + Self { + endpoint, + access_token, + subscriptions: Vec::new(), + backoff, + timeout, + } + } + + pub async fn run( + &mut self, + response_sender: mpsc::Sender, + request_receiver: &mut mpsc::Receiver, + ) -> Result<()> { + loop { + let start_time = Instant::now(); + if let Err(e) = self.start(response_sender.clone(), request_receiver).await { + // If a connection was working for BACKOFF_RESET_DURATION + // and timeout + 1sec, it was considered successful therefore reset the backoff + if start_time.elapsed() > BACKOFF_RESET_DURATION + && start_time.elapsed() > self.timeout + Duration::from_secs(1) + { + self.backoff.reset(); + } + + let delay = self.backoff.next_backoff(); + match delay { + Some(d) => { + info!("WebSocket connection failed: {}. Retrying in {:?}", e, d); + tokio::time::sleep(d).await; + } + None => { + bail!( + "Max retries reached for WebSocket connection to {}, this should never happen, please contact developers", + self.endpoint + ); + } + } + } + } + } + + pub async fn start( + &mut self, + sender: mpsc::Sender, + request_receiver: &mut mpsc::Receiver, + ) -> Result<()> { + let mut ws_connection = + PythLazerWSConnection::new(self.endpoint.clone(), self.access_token.clone())?; + let stream = ws_connection.start().await?; + pin!(stream); + + for subscription in self.subscriptions.clone() { + ws_connection + .send_request(Request::Subscribe(subscription)) + .await?; + } + loop { + let timeout_response = tokio::time::timeout(self.timeout, stream.next()); + + select! { + response = timeout_response => { + match response { + Ok(Some(response)) => match response { + Ok(response) => { + sender + .send(response) + .await + .context("Failed to send response")?; + } + Err(e) => { + bail!("WebSocket stream error: {}", e); + } + }, + Ok(None) => { + bail!("WebSocket stream ended unexpectedly"); + } + Err(_elapsed) => { + bail!("WebSocket stream timed out"); + } + } + } + Some(request) = request_receiver.recv() => { + match request { + Request::Subscribe(request) => { + self.subscribe(&mut ws_connection, request).await?; + } + Request::Unsubscribe(request) => { + self.unsubscribe(&mut ws_connection, request).await?; + } + } + } + } + } + } + + pub async fn subscribe( + &mut self, + ws_connection: &mut PythLazerWSConnection, + request: SubscribeRequest, + ) -> Result<()> { + self.subscriptions.push(request.clone()); + ws_connection.subscribe(request).await + } + + pub async fn unsubscribe( + &mut self, + ws_connection: &mut PythLazerWSConnection, + request: UnsubscribeRequest, + ) -> Result<()> { + if let Some(index) = self + .subscriptions + .iter() + .position(|r| r.subscription_id == request.subscription_id) + { + self.subscriptions.remove(index); + } else { + warn!( + "Unsubscribe called for non-existent subscription: {:?}", + request.subscription_id + ); + } + ws_connection.unsubscribe(request).await + } +} diff --git a/lazer/sdk/rust/client/src/ws_connection.rs b/lazer/sdk/rust/client/src/ws_connection.rs new file mode 100644 index 0000000000..385bd2efd7 --- /dev/null +++ b/lazer/sdk/rust/client/src/ws_connection.rs @@ -0,0 +1,144 @@ +use std::hash::{DefaultHasher, Hash, Hasher}; + +use anyhow::Result; +use derive_more::From; +use futures_util::{SinkExt, StreamExt, TryStreamExt}; +use pyth_lazer_protocol::{ + binary_update::BinaryWsUpdate, + subscription::{ErrorResponse, Request, Response, SubscribeRequest, UnsubscribeRequest}, +}; +use tokio_tungstenite::{connect_async, tungstenite::Message}; +use url::Url; + +/// A WebSocket client for consuming Pyth Lazer price feed updates +/// +/// This client provides a simple interface to: +/// - Connect to a Lazer WebSocket endpoint +/// - Subscribe to price feed updates +/// - Receive updates as a stream of messages +/// +pub struct PythLazerWSConnection { + endpoint: Url, + access_token: String, + ws_sender: Option< + futures_util::stream::SplitSink< + tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream, + >, + Message, + >, + >, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, From)] +pub enum AnyResponse { + Json(Response), + Binary(BinaryWsUpdate), +} + +impl AnyResponse { + pub fn cache_key(&self) -> u64 { + let mut hasher = DefaultHasher::new(); + self.hash(&mut hasher); + hasher.finish() + } +} +impl PythLazerWSConnection { + /// Creates a new Lazer client instance + /// + /// # Arguments + /// * `endpoint` - The WebSocket URL of the Lazer service + /// * `access_token` - Access token for authentication + /// + /// # Returns + /// Returns a new client instance (not yet connected) + pub fn new(endpoint: Url, access_token: String) -> Result { + Ok(Self { + endpoint, + access_token, + ws_sender: None, + }) + } + + /// Starts the WebSocket connection + /// + /// # Returns + /// Returns a stream of responses from the server + pub async fn start(&mut self) -> Result>> { + let url = self.endpoint.clone(); + let mut request = + tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(url)?; + + request.headers_mut().insert( + "Authorization", + format!("Bearer {}", self.access_token).parse().unwrap(), + ); + + let (ws_stream, _) = connect_async(request).await?; + let (ws_sender, ws_receiver) = ws_stream.split(); + + self.ws_sender = Some(ws_sender); + let response_stream = + ws_receiver + .map_err(anyhow::Error::from) + .try_filter_map(|msg| async { + let r: Result> = match msg { + Message::Text(text) => { + Ok(Some(serde_json::from_str::(&text)?.into())) + } + Message::Binary(data) => { + Ok(Some(BinaryWsUpdate::deserialize_slice(&data)?.into())) + } + Message::Close(_) => Ok(Some( + Response::Error(ErrorResponse { + error: "WebSocket connection closed".to_string(), + }) + .into(), + )), + _ => Ok(None), + }; + r + }); + + Ok(response_stream) + } + + pub async fn send_request(&mut self, request: Request) -> Result<()> { + if let Some(sender) = &mut self.ws_sender { + let msg = serde_json::to_string(&request)?; + sender.send(Message::Text(msg)).await?; + Ok(()) + } else { + anyhow::bail!("WebSocket connection not started") + } + } + + /// Subscribes to price feed updates + /// + /// # Arguments + /// * `request` - A subscription request containing feed IDs and parameters + pub async fn subscribe(&mut self, request: SubscribeRequest) -> Result<()> { + let request = Request::Subscribe(request); + self.send_request(request).await + } + + /// Unsubscribes from a previously subscribed feed + /// + /// # Arguments + /// * `subscription_id` - The ID of the subscription to cancel + pub async fn unsubscribe(&mut self, request: UnsubscribeRequest) -> Result<()> { + let request = Request::Unsubscribe(request); + self.send_request(request).await + } + + /// Closes the WebSocket connection + pub async fn close(&mut self) -> Result<()> { + if let Some(sender) = &mut self.ws_sender { + sender.send(Message::Close(None)).await?; + self.ws_sender = None; + Ok(()) + } else { + anyhow::bail!("WebSocket connection not started") + } + } +}