@@ -159,3 +159,216 @@ where
159159 Ok ( ( ) )
160160 }
161161}
162+
163+ #[ cfg( test) ]
164+ mod tests {
165+ use super :: * ;
166+ use crate :: metrics:: Metrics ;
167+ use axum:: http:: Uri ;
168+ use futures:: SinkExt ;
169+ use std:: net:: SocketAddr ;
170+ use std:: sync:: { Arc , Mutex } ;
171+ use tokio:: net:: { TcpListener , TcpStream } ;
172+ use tokio:: sync:: broadcast;
173+ use tokio:: time:: { sleep, timeout, Duration } ;
174+ use tokio_tungstenite:: { accept_async, tungstenite:: Message } ;
175+
176+ struct MockServer {
177+ addr : SocketAddr ,
178+ message_sender : broadcast:: Sender < String > ,
179+ shutdown : CancellationToken ,
180+ }
181+
182+ impl MockServer {
183+ async fn new ( ) -> Self {
184+ let listener = TcpListener :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ;
185+ let addr = listener. local_addr ( ) . unwrap ( ) ;
186+ let ( tx, _) = broadcast:: channel :: < String > ( 100 ) ;
187+ let shutdown = CancellationToken :: new ( ) ;
188+ let shutdown_clone = shutdown. clone ( ) ;
189+ let tx_clone = tx. clone ( ) ;
190+
191+ tokio:: spawn ( async move {
192+ loop {
193+ select ! {
194+ _ = shutdown_clone. cancelled( ) => {
195+ break ;
196+ }
197+ accept_result = listener. accept( ) => {
198+ match accept_result {
199+ Ok ( ( stream, _) ) => {
200+ let tx = tx_clone. clone( ) ;
201+ let shutdown = shutdown_clone. clone( ) ;
202+ tokio:: spawn( async move {
203+ Self :: handle_connection( stream, tx, shutdown) . await ;
204+ } ) ;
205+ }
206+ Err ( e) => {
207+ eprintln!( "Failed to accept: {}" , e) ;
208+ break ;
209+ }
210+ }
211+ }
212+ }
213+ }
214+ } ) ;
215+
216+ Self {
217+ addr,
218+ message_sender : tx,
219+ shutdown,
220+ }
221+ }
222+
223+ async fn handle_connection (
224+ stream : TcpStream ,
225+ tx : broadcast:: Sender < String > ,
226+ shutdown : CancellationToken ,
227+ ) {
228+ let ws_stream = match accept_async ( stream) . await {
229+ Ok ( ws_stream) => ws_stream,
230+ Err ( e) => {
231+ eprintln ! ( "Failed to accept websocket: {}" , e) ;
232+ return ;
233+ }
234+ } ;
235+
236+ let ( mut ws_sender, _) = ws_stream. split ( ) ;
237+
238+ let mut rx = tx. subscribe ( ) ;
239+
240+ loop {
241+ select ! {
242+ _ = shutdown. cancelled( ) => {
243+ break ;
244+ }
245+ msg = rx. recv( ) => {
246+ match msg {
247+ Ok ( text) => {
248+ if let Err ( e) = ws_sender. send( Message :: Text ( text. into( ) ) ) . await {
249+ eprintln!( "Error sending message: {}" , e) ;
250+ break ;
251+ }
252+ }
253+ Err ( _) => {
254+ break ;
255+ }
256+ }
257+ }
258+ }
259+ }
260+ }
261+
262+ async fn send_message (
263+ & self ,
264+ msg : & str ,
265+ ) -> Result < usize , broadcast:: error:: SendError < String > > {
266+ self . message_sender . send ( msg. to_string ( ) )
267+ }
268+
269+ async fn shutdown ( self ) {
270+ self . shutdown . cancel ( ) ;
271+ }
272+
273+ fn uri ( & self ) -> Uri {
274+ format ! ( "ws://{}" , self . addr)
275+ . parse ( )
276+ . expect ( "Failed to parse URI" )
277+ }
278+ }
279+
280+ #[ tokio:: test]
281+ async fn test_multiple_subscribers_single_listener ( ) {
282+ // Create two mock servers
283+ let server1 = MockServer :: new ( ) . await ;
284+ let server2 = MockServer :: new ( ) . await ;
285+
286+ // Create a receiver for the messages
287+ let received_messages = Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ;
288+ let received_clone = received_messages. clone ( ) ;
289+
290+ // Create a listener function that will be shared by both subscribers
291+ let listener = move |data : String | {
292+ if let Ok ( mut messages) = received_clone. lock ( ) {
293+ messages. push ( data) ;
294+ }
295+ } ;
296+
297+ // Create metrics
298+ let metrics = Arc :: new ( Metrics :: default ( ) ) ;
299+
300+ // Create cancellation token
301+ let token = CancellationToken :: new ( ) ;
302+ let token_clone1 = token. clone ( ) ;
303+ let token_clone2 = token. clone ( ) ;
304+
305+ // Create and run the first subscriber
306+ let uri1 = server1. uri ( ) ;
307+ let listener_clone1 = listener. clone ( ) ;
308+ let metrics_clone1 = metrics. clone ( ) ;
309+
310+ let mut subscriber1 =
311+ WebsocketSubscriber :: new ( uri1. clone ( ) , listener_clone1, 5 , metrics_clone1) ;
312+
313+ // Create and run the second subscriber
314+ let uri2 = server2. uri ( ) ;
315+ let listener_clone2 = listener. clone ( ) ;
316+ let metrics_clone2 = metrics. clone ( ) ;
317+
318+ let mut subscriber2 =
319+ WebsocketSubscriber :: new ( uri2. clone ( ) , listener_clone2, 5 , metrics_clone2) ;
320+
321+ // Spawn tasks for subscribers
322+ let task1 = tokio:: spawn ( async move {
323+ subscriber1. run ( token_clone1) . await ;
324+ } ) ;
325+
326+ let task2 = tokio:: spawn ( async move {
327+ subscriber2. run ( token_clone2) . await ;
328+ } ) ;
329+
330+ // Wait for connections to establish
331+ sleep ( Duration :: from_millis ( 500 ) ) . await ;
332+
333+ // Send different messages from each server
334+ let _ = server1. send_message ( "Message from server 1" ) . await ;
335+ let _ = server2. send_message ( "Message from server 2" ) . await ;
336+
337+ // Wait for messages to be processed
338+ sleep ( Duration :: from_millis ( 500 ) ) . await ;
339+
340+ // Send more messages to ensure continuous operation
341+ let _ = server1. send_message ( "Another message from server 1" ) . await ;
342+ let _ = server2. send_message ( "Another message from server 2" ) . await ;
343+
344+ // Wait for messages to be processed
345+ sleep ( Duration :: from_millis ( 500 ) ) . await ;
346+
347+ // Cancel the token to shut down subscribers
348+ token. cancel ( ) ;
349+
350+ // Wait for tasks to complete
351+ let _ = timeout ( Duration :: from_secs ( 1 ) , task1) . await ;
352+ let _ = timeout ( Duration :: from_secs ( 1 ) , task2) . await ;
353+
354+ // Shutdown the mock servers
355+ server1. shutdown ( ) . await ;
356+ server2. shutdown ( ) . await ;
357+
358+ // Verify that messages were received
359+ let messages = match received_messages. lock ( ) {
360+ Ok ( guard) => guard,
361+ Err ( poisoned) => poisoned. into_inner ( ) ,
362+ } ;
363+
364+ assert_eq ! ( messages. len( ) , 4 ) ;
365+
366+ // Check that we received messages from both servers
367+ assert ! ( messages. contains( & "Message from server 1" . to_string( ) ) ) ;
368+ assert ! ( messages. contains( & "Message from server 2" . to_string( ) ) ) ;
369+ assert ! ( messages. contains( & "Another message from server 1" . to_string( ) ) ) ;
370+ assert ! ( messages. contains( & "Another message from server 2" . to_string( ) ) ) ;
371+
372+ assert ! ( messages. len( ) > 0 ) ;
373+ }
374+ }
0 commit comments