Skip to content
This repository was archived by the owner on May 27, 2025. It is now read-only.

Commit e1c995b

Browse files
committed
add tests
1 parent c496b2b commit e1c995b

File tree

1 file changed

+213
-0
lines changed

1 file changed

+213
-0
lines changed

src/subscriber.rs

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,216 @@ where
159159
Ok(())
160160
}
161161
}
162+
163+
#[cfg(test)]
164+
mod tests {
165+
use super::*;
166+
use crate::metrics::Metrics;
167+
use axum::http::Uri;
168+
use futures::SinkExt;
169+
use std::net::SocketAddr;
170+
use std::sync::{Arc, Mutex};
171+
use tokio::net::{TcpListener, TcpStream};
172+
use tokio::sync::broadcast;
173+
use tokio::time::{sleep, timeout, Duration};
174+
use tokio_tungstenite::{accept_async, tungstenite::Message};
175+
176+
struct MockServer {
177+
addr: SocketAddr,
178+
message_sender: broadcast::Sender<String>,
179+
shutdown: CancellationToken,
180+
}
181+
182+
impl MockServer {
183+
async fn new() -> Self {
184+
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
185+
let addr = listener.local_addr().unwrap();
186+
let (tx, _) = broadcast::channel::<String>(100);
187+
let shutdown = CancellationToken::new();
188+
let shutdown_clone = shutdown.clone();
189+
let tx_clone = tx.clone();
190+
191+
tokio::spawn(async move {
192+
loop {
193+
select! {
194+
_ = shutdown_clone.cancelled() => {
195+
break;
196+
}
197+
accept_result = listener.accept() => {
198+
match accept_result {
199+
Ok((stream, _)) => {
200+
let tx = tx_clone.clone();
201+
let shutdown = shutdown_clone.clone();
202+
tokio::spawn(async move {
203+
Self::handle_connection(stream, tx, shutdown).await;
204+
});
205+
}
206+
Err(e) => {
207+
eprintln!("Failed to accept: {}", e);
208+
break;
209+
}
210+
}
211+
}
212+
}
213+
}
214+
});
215+
216+
Self {
217+
addr,
218+
message_sender: tx,
219+
shutdown,
220+
}
221+
}
222+
223+
async fn handle_connection(
224+
stream: TcpStream,
225+
tx: broadcast::Sender<String>,
226+
shutdown: CancellationToken,
227+
) {
228+
let ws_stream = match accept_async(stream).await {
229+
Ok(ws_stream) => ws_stream,
230+
Err(e) => {
231+
eprintln!("Failed to accept websocket: {}", e);
232+
return;
233+
}
234+
};
235+
236+
let (mut ws_sender, _) = ws_stream.split();
237+
238+
let mut rx = tx.subscribe();
239+
240+
loop {
241+
select! {
242+
_ = shutdown.cancelled() => {
243+
break;
244+
}
245+
msg = rx.recv() => {
246+
match msg {
247+
Ok(text) => {
248+
if let Err(e) = ws_sender.send(Message::Text(text.into())).await {
249+
eprintln!("Error sending message: {}", e);
250+
break;
251+
}
252+
}
253+
Err(_) => {
254+
break;
255+
}
256+
}
257+
}
258+
}
259+
}
260+
}
261+
262+
async fn send_message(
263+
&self,
264+
msg: &str,
265+
) -> Result<usize, broadcast::error::SendError<String>> {
266+
self.message_sender.send(msg.to_string())
267+
}
268+
269+
async fn shutdown(self) {
270+
self.shutdown.cancel();
271+
}
272+
273+
fn uri(&self) -> Uri {
274+
format!("ws://{}", self.addr)
275+
.parse()
276+
.expect("Failed to parse URI")
277+
}
278+
}
279+
280+
#[tokio::test]
281+
async fn test_multiple_subscribers_single_listener() {
282+
// Create two mock servers
283+
let server1 = MockServer::new().await;
284+
let server2 = MockServer::new().await;
285+
286+
// Create a receiver for the messages
287+
let received_messages = Arc::new(Mutex::new(Vec::new()));
288+
let received_clone = received_messages.clone();
289+
290+
// Create a listener function that will be shared by both subscribers
291+
let listener = move |data: String| {
292+
if let Ok(mut messages) = received_clone.lock() {
293+
messages.push(data);
294+
}
295+
};
296+
297+
// Create metrics
298+
let metrics = Arc::new(Metrics::default());
299+
300+
// Create cancellation token
301+
let token = CancellationToken::new();
302+
let token_clone1 = token.clone();
303+
let token_clone2 = token.clone();
304+
305+
// Create and run the first subscriber
306+
let uri1 = server1.uri();
307+
let listener_clone1 = listener.clone();
308+
let metrics_clone1 = metrics.clone();
309+
310+
let mut subscriber1 =
311+
WebsocketSubscriber::new(uri1.clone(), listener_clone1, 5, metrics_clone1);
312+
313+
// Create and run the second subscriber
314+
let uri2 = server2.uri();
315+
let listener_clone2 = listener.clone();
316+
let metrics_clone2 = metrics.clone();
317+
318+
let mut subscriber2 =
319+
WebsocketSubscriber::new(uri2.clone(), listener_clone2, 5, metrics_clone2);
320+
321+
// Spawn tasks for subscribers
322+
let task1 = tokio::spawn(async move {
323+
subscriber1.run(token_clone1).await;
324+
});
325+
326+
let task2 = tokio::spawn(async move {
327+
subscriber2.run(token_clone2).await;
328+
});
329+
330+
// Wait for connections to establish
331+
sleep(Duration::from_millis(500)).await;
332+
333+
// Send different messages from each server
334+
let _ = server1.send_message("Message from server 1").await;
335+
let _ = server2.send_message("Message from server 2").await;
336+
337+
// Wait for messages to be processed
338+
sleep(Duration::from_millis(500)).await;
339+
340+
// Send more messages to ensure continuous operation
341+
let _ = server1.send_message("Another message from server 1").await;
342+
let _ = server2.send_message("Another message from server 2").await;
343+
344+
// Wait for messages to be processed
345+
sleep(Duration::from_millis(500)).await;
346+
347+
// Cancel the token to shut down subscribers
348+
token.cancel();
349+
350+
// Wait for tasks to complete
351+
let _ = timeout(Duration::from_secs(1), task1).await;
352+
let _ = timeout(Duration::from_secs(1), task2).await;
353+
354+
// Shutdown the mock servers
355+
server1.shutdown().await;
356+
server2.shutdown().await;
357+
358+
// Verify that messages were received
359+
let messages = match received_messages.lock() {
360+
Ok(guard) => guard,
361+
Err(poisoned) => poisoned.into_inner(),
362+
};
363+
364+
assert_eq!(messages.len(), 4);
365+
366+
// Check that we received messages from both servers
367+
assert!(messages.contains(&"Message from server 1".to_string()));
368+
assert!(messages.contains(&"Message from server 2".to_string()));
369+
assert!(messages.contains(&"Another message from server 1".to_string()));
370+
assert!(messages.contains(&"Another message from server 2".to_string()));
371+
372+
assert!(messages.len() > 0);
373+
}
374+
}

0 commit comments

Comments
 (0)