From c39b25edcf9be07366f3d657cb416939f9774a6f Mon Sep 17 00:00:00 2001 From: Mahmoud Mazouz Date: Thu, 3 Oct 2024 13:26:26 +0000 Subject: [PATCH] Encapsulate `flume::Receiver` --- zenoh/src/api/handlers/fifo.rs | 212 ++++++++++++++++++++++++++++++++- zenoh/tests/matching.rs | 29 +++-- 2 files changed, 230 insertions(+), 11 deletions(-) diff --git a/zenoh/src/api/handlers/fifo.rs b/zenoh/src/api/handlers/fifo.rs index 44a542e538..eb6d7df54c 100644 --- a/zenoh/src/api/handlers/fifo.rs +++ b/zenoh/src/api/handlers/fifo.rs @@ -14,7 +14,12 @@ //! Callback handler trait. -use std::sync::Arc; +use std::{ + sync::Arc, + time::{Duration, Instant}, +}; + +use zenoh_result::ZResult; use crate::api::handlers::{callback::Callback, IntoHandler, API_DATA_RECEPTION_CHANNEL_SIZE}; @@ -41,11 +46,212 @@ impl Default for FifoChannel { } } +/// [`FifoChannel`] handler. +#[derive(Debug, Clone)] +pub struct FifoChannelHandler(flume::Receiver); + impl IntoHandler for FifoChannel { - type Handler = flume::Receiver; + type Handler = FifoChannelHandler; fn into_handler(self) -> (Callback, Self::Handler) { - flume::bounded(self.capacity).into_handler() + let (sender, receiver) = flume::bounded(self.capacity); + ( + Callback::new(Arc::new(move |t| { + if let Err(error) = sender.send(t) { + tracing::error!(%error) + } + })), + FifoChannelHandler(receiver), + ) + } +} + +impl FifoChannelHandler { + /// Attempt to fetch an incoming value from the channel associated with this receiver, returning + /// an error if the channel is empty or if all senders have been dropped. + pub fn try_recv(&self) -> ZResult { + self.0.try_recv().map_err(Into::into) + } + + /// Wait for an incoming value from the channel associated with this receiver, returning an + /// error if all senders have been dropped. + pub fn recv(&self) -> ZResult { + self.0.recv().map_err(Into::into) + } + + /// Wait for an incoming value from the channel associated with this receiver, returning an + /// error if all senders have been dropped or the deadline has passed. + pub fn recv_deadline(&self, deadline: Instant) -> ZResult { + self.0.recv_deadline(deadline).map_err(Into::into) + } + + /// Wait for an incoming value from the channel associated with this receiver, returning an + /// error if all senders have been dropped or the timeout has expired. + pub fn recv_timeout(&self, dur: Duration) -> ZResult { + self.0 + .recv_deadline(Instant::now().checked_add(dur).unwrap()) + .map_err(Into::into) + } + + /// Create a blocking iterator over the values received on the channel that finishes iteration + /// when all senders have been dropped. + pub fn iter(&self) -> Iter<'_, T> { + Iter(self.0.iter()) + } + + /// A non-blocking iterator over the values received on the channel that finishes iteration when + /// all senders have been dropped or the channel is empty. + pub fn try_iter(&self) -> TryIter<'_, T> { + TryIter(self.0.try_iter()) + } + + /// Take all msgs currently sitting in the channel and produce an iterator over them. Unlike + /// `try_iter`, the iterator will not attempt to fetch any more values from the channel once + /// the function has been called. + pub fn drain(&self) -> Drain<'_, T> { + Drain(self.0.drain()) + } + + /// Returns true if all senders for this channel have been dropped. + pub fn is_disconnected(&self) -> bool { + self.0.is_disconnected() + } + + /// Returns true if the channel is empty. + /// Note: Zero-capacity channels are always empty. + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// Returns true if the channel is full. + /// Note: Zero-capacity channels are always full. + pub fn is_full(&self) -> bool { + self.0.is_full() + } + + /// Returns the number of messages in the channel. + pub fn len(&self) -> usize { + self.0.len() + } + + /// If the channel is bounded, returns its capacity. + pub fn capacity(&self) -> Option { + self.0.capacity() + } + + /// Get the number of senders that currently exist. + pub fn sender_count(&self) -> usize { + self.0.sender_count() + } + + /// Get the number of receivers that currently exist, including this one. + pub fn receiver_count(&self) -> usize { + self.0.receiver_count() + } + + /// Returns whether the receivers are belong to the same channel. + pub fn same_channel(&self, other: &Self) -> bool { + self.0.same_channel(&other.0) + } +} + +/// This exists as a shorthand for [`FifoChannelHandler::iter`]. +impl<'a, T> IntoIterator for &'a FifoChannelHandler { + type Item = T; + type IntoIter = Iter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + Iter(self.0.iter()) + } +} + +impl IntoIterator for FifoChannelHandler { + type Item = T; + type IntoIter = IntoIter; + + /// Creates a self-owned but semantically equivalent alternative to [`FifoChannelHandler::iter`]. + fn into_iter(self) -> Self::IntoIter { + IntoIter(self.0.into_iter()) + } +} + +/// An iterator over the msgs received from a channel. +pub struct Iter<'a, T>(flume::Iter<'a, T>); + +impl<'a, T> Iterator for Iter<'a, T> { + type Item = T; + + fn next(&mut self) -> Option { + self.0.next() + } +} + +/// An non-blocking iterator over the msgs received from a channel. +pub struct TryIter<'a, T>(flume::TryIter<'a, T>); + +impl<'a, T> Iterator for TryIter<'a, T> { + type Item = T; + + fn next(&mut self) -> Option { + self.0.next() + } +} + +/// An fixed-sized iterator over the msgs drained from a channel. +#[derive(Debug)] +pub struct Drain<'a, T>(flume::Drain<'a, T>); + +impl<'a, T> Iterator for Drain<'a, T> { + type Item = T; + + fn next(&mut self) -> Option { + self.0.next() + } +} + +impl<'a, T> ExactSizeIterator for Drain<'a, T> { + fn len(&self) -> usize { + self.0.len() + } +} + +/// An owned iterator over the msgs received from a channel. +pub struct IntoIter(flume::IntoIter); + +impl Iterator for IntoIter { + type Item = T; + + fn next(&mut self) -> Option { + self.0.next() + } +} + +impl FifoChannelHandler { + /// Asynchronously receive a value from the channel, returning an error if all senders have been + /// dropped. If the channel is empty, the returned future will yield to the async runtime. + pub async fn recv_async(&self) -> ZResult { + self.0.recv_async().await.map_err(Into::into) + } + + /// Convert this receiver into a future that asynchronously receives a single message from the + /// channel, returning an error if all senders have been dropped. If the channel is empty, this + /// future will yield to the async runtime. + pub async fn into_recv_async(self) -> ZResult { + self.0.into_recv_async().await.map_err(Into::into) + } + + /// Create an asynchronous stream that uses this receiver to asynchronously receive messages + /// from the channel. The receiver will continue to be usable after the stream has been dropped. + pub fn stream(&self) -> impl futures::Stream + '_ { + self.0.stream() + } +} + +impl<'a, T: 'a> FifoChannelHandler { + /// Convert this receiver into a stream that allows asynchronously receiving messages from the + /// channel. + pub fn into_stream(self) -> impl futures::Stream + 'a { + self.0.into_stream() } } diff --git a/zenoh/tests/matching.rs b/zenoh/tests/matching.rs index efa377863d..16f1376507 100644 --- a/zenoh/tests/matching.rs +++ b/zenoh/tests/matching.rs @@ -15,7 +15,6 @@ use std::time::Duration; -use flume::RecvTimeoutError; use zenoh::{sample::Locality, Result as ZResult, Session}; use zenoh_config::{ModeDependentValue, WhatAmI}; use zenoh_core::ztimeout; @@ -59,7 +58,9 @@ async fn zenoh_matching_status_any() -> ZResult<()> { let matching_listener = ztimeout!(publisher1.matching_listener()).unwrap(); let received_status = matching_listener.recv_timeout(RECV_TIMEOUT); - assert!(received_status.err() == Some(RecvTimeoutError::Timeout)); + assert!( + received_status.err().unwrap().downcast_ref() == Some(&flume::RecvTimeoutError::Timeout) + ); let matching_status = ztimeout!(publisher1.matching_status()).unwrap(); assert!(!matching_status.matching_subscribers()); @@ -113,7 +114,9 @@ async fn zenoh_matching_status_remote() -> ZResult<()> { let matching_listener = ztimeout!(publisher1.matching_listener()).unwrap(); let received_status = matching_listener.recv_timeout(RECV_TIMEOUT); - assert!(received_status.err() == Some(RecvTimeoutError::Timeout)); + assert!( + received_status.err().unwrap().downcast_ref() == Some(&flume::RecvTimeoutError::Timeout) + ); let matching_status = ztimeout!(publisher1.matching_status()).unwrap(); assert!(!matching_status.matching_subscribers()); @@ -121,7 +124,9 @@ async fn zenoh_matching_status_remote() -> ZResult<()> { let sub = ztimeout!(session1.declare_subscriber("zenoh_matching_status_remote_test")).unwrap(); let received_status = matching_listener.recv_timeout(RECV_TIMEOUT); - assert!(received_status.err() == Some(RecvTimeoutError::Timeout)); + assert!( + received_status.err().unwrap().downcast_ref() == Some(&flume::RecvTimeoutError::Timeout) + ); let matching_status = ztimeout!(publisher1.matching_status()).unwrap(); assert!(!matching_status.matching_subscribers()); @@ -129,7 +134,9 @@ async fn zenoh_matching_status_remote() -> ZResult<()> { ztimeout!(sub.undeclare()).unwrap(); let received_status = matching_listener.recv_timeout(RECV_TIMEOUT); - assert!(received_status.err() == Some(RecvTimeoutError::Timeout)); + assert!( + received_status.err().unwrap().downcast_ref() == Some(&flume::RecvTimeoutError::Timeout) + ); let matching_status = ztimeout!(publisher1.matching_status()).unwrap(); assert!(!matching_status.matching_subscribers()); @@ -168,7 +175,9 @@ async fn zenoh_matching_status_local() -> ZResult<()> { let matching_listener = ztimeout!(publisher1.matching_listener()).unwrap(); let received_status = matching_listener.recv_timeout(RECV_TIMEOUT); - assert!(received_status.err() == Some(RecvTimeoutError::Timeout)); + assert!( + received_status.err().unwrap().downcast_ref() == Some(&flume::RecvTimeoutError::Timeout) + ); let matching_status = ztimeout!(publisher1.matching_status()).unwrap(); assert!(!matching_status.matching_subscribers()); @@ -192,7 +201,9 @@ async fn zenoh_matching_status_local() -> ZResult<()> { let sub = ztimeout!(session2.declare_subscriber("zenoh_matching_status_local_test")).unwrap(); let received_status = matching_listener.recv_timeout(RECV_TIMEOUT); - assert!(received_status.err() == Some(RecvTimeoutError::Timeout)); + assert!( + received_status.err().unwrap().downcast_ref() == Some(&flume::RecvTimeoutError::Timeout) + ); let matching_status = ztimeout!(publisher1.matching_status()).unwrap(); assert!(!matching_status.matching_subscribers()); @@ -200,7 +211,9 @@ async fn zenoh_matching_status_local() -> ZResult<()> { ztimeout!(sub.undeclare()).unwrap(); let received_status = matching_listener.recv_timeout(RECV_TIMEOUT); - assert!(received_status.err() == Some(RecvTimeoutError::Timeout)); + assert!( + received_status.err().unwrap().downcast_ref() == Some(&flume::RecvTimeoutError::Timeout) + ); let matching_status = ztimeout!(publisher1.matching_status()).unwrap(); assert!(!matching_status.matching_subscribers());