From bd98d01f484a84f77f852aa56cc4bcd0cece1b1d Mon Sep 17 00:00:00 2001 From: Jonathan Johnson Date: Mon, 20 Jan 2025 13:06:01 -0800 Subject: [PATCH] Added broadcast channel Also flushed out the APIs in general. --- src/reactive.rs | 28 +- src/reactive/channel.rs | 1709 ++++++++++++++++++++++++------- src/reactive/channel/builder.rs | 341 ++++++ 3 files changed, 1703 insertions(+), 375 deletions(-) create mode 100644 src/reactive/channel/builder.rs diff --git a/src/reactive.rs b/src/reactive.rs index 182a1764b..357376a38 100644 --- a/src/reactive.rs +++ b/src/reactive.rs @@ -191,6 +191,12 @@ struct Futures { } impl Futures { + fn spawn(&mut self, future: PollChannelFuture) -> usize { + let id = self.push(future); + self.queue.push_back(id); + id + } + fn push(&mut self, future: PollChannelFuture) -> usize { let mut id = None; while !self.available.is_empty() { @@ -228,7 +234,8 @@ impl Futures { let mut ctx = Context::from_waker(®istered.waker); match Pin::new(future).poll(&mut ctx) { Poll::Ready(()) => { - self.registered.remove(id); + registered.future = None; + self.available.insert(id); callbacks_executed += 1; } Poll::Pending => {} @@ -305,7 +312,9 @@ impl CallbackExecutor { self.channels.notify(id, &mut self.futures); } ChannelTask::Unregister(id) => { - self.channels.unregister(id); + if let Some(future_id) = self.channels.unregister(id) { + self.futures.wake(future_id); + } } }, BackgroundTask::ExecuteCallbacks(callbacks) => { @@ -334,7 +343,7 @@ impl WatchedChannels { return; }; let future_id = channel.should_poll().then(|| { - futures.push(PollChannelFuture { + futures.spawn(PollChannelFuture { channel: channel.clone(), futures: Vec::new(), }) @@ -364,11 +373,11 @@ impl WatchedChannels { .push_back(channel.future_id.expect("initialized above")); } - fn unregister(&mut self, id: usize) { - let Some(id) = self.by_id.remove(&id) else { - return; - }; - self.registry.remove(id); + fn unregister(&mut self, id: usize) -> Option { + let id = self.by_id.remove(&id)?; + self.registry + .remove(id) + .and_then(|removed| removed.future_id) } } @@ -388,6 +397,7 @@ impl Future for PollChannelFuture { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = &mut *self; if this.futures.is_empty() && !this.channel.poll(&mut this.futures) { + this.channel.disconnect(); return Poll::Ready(()); } loop { @@ -399,7 +409,7 @@ impl Future for PollChannelFuture { match result { Ok(()) => {} Err(CallbackDisconnected) => { - self.channel.disconnect_callback(); + self.channel.disconnect(); } } completed_one = true; diff --git a/src/reactive/channel.rs b/src/reactive/channel.rs index 74e3df4c7..113bb1dff 100644 --- a/src/reactive/channel.rs +++ b/src/reactive/channel.rs @@ -1,55 +1,194 @@ -//! A reactive multi-sender, single-consumer (mpsc) channel for Cushy. - +//! Reactive channels for Cushy +//! +//! Channels ensure that every message sent is delivered to a receiver. Dynamics +//! contain values and can provide reactivity, but if a dynamic is updated more +//! quickly than the change callbacks are executed, it is possible for change +//! callbacks to not observe every value stored in the Dynamic. Channels allow +//! building data flows that ensure every value written is observed. +//! +//! Cushy supports two types of channels: +//! +//! - Multi-Producer, Single-Consumer (MPSC): One or more [`Sender`]s send +//! values to either a [`Receiver`] or a callback function. +//! +//! Created by: +//! +//! - [`unbounded()`] +//! - [`bounded()`] +//! - [`build()`]`.finish()`/`build().bound(capacity).finish()` +//! +//! - Broadcast: A [`Broadcaster`] or a [`BroadcastChannel`] sends values +//! to one or more callback functions. This type requires `T` to implement +//! `Clone` as each callback receives its own clone of the value being +//! broadcast. +//! +//! Broadcast channels ensure every callback associated is completed for each +//! value received before receiving the next value. +//! +//! Created by: +//! - [`BroadcastChannel::unbounded()`] +//! - [`BroadcastChannel::bounded()`] +//! - [`build()`]`broadcasting().finish()`/`build().bound(capacity).broadcasting().finish()` +//! +//! All channel types support being either unbounded or bounded. An unbounded +//! channel dynamically allocates its queue and grows as needed. It can cause +//! unexpected memory use or panics if the queues grow too large for the +//! available system memory. Bounded channels allocate a buffer of a known +//! capacity and can block on send or return errors when the queue is full. +//! +//! One of the features provided by Cushy's channels are the abilility to attach +//! callbacks to be executed when values are sent. Instead of needing to +//! manually spawn threads or async tasks, these callbacks are automatically +//! scheduled by Cushy, making channel reactivity feel similar to +//! [`Dynamic`](crate::value::Dynamic) reactivity. However, channels +//! guarantee that the callbacks associated with them receive *every* value +//! written, while dynamics only guarantee that the latest written value will be +//! observed. +//! +//! # Blocking callbacks +//! +//! When a callback might block while waiting on another thread, a network task, +//! or some other operation that may take a long time or require synchronization +//! that could block (e.g., mutexes), it should be considered a *blocking* +//! callback. Each blocking callback is executed in a way that ensures it cannot +//! block any other operation while waiting for new values to be sent. +//! +//! These callbacks can be configured using: +//! +//! - [`Receiver::on_receive`] +//! - [`BroadcastChannel::on_receive`] +//! - [`Builder::on_receive`] +//! +//! # Non-blocking callbacks +//! +//! When a callback will never block for a significant amount of time or in a +//! way that depends on other threads or callbacks or external resources, a +//! non-blocking callback can be used. These callbacks are executed in a shared +//! execution environment that minimizes resource consumption compared to what +//! is required to execute blocking callbacks. +//! +//! These callbacks can be configured using: +//! +//! - [`Receiver::on_receive_nonblocking`] +//! - [`BroadcastChannel::on_receive_nonblocking`] +//! - [`Builder::on_receive_nonblocking`] +//! +//! # Async callbacks +//! +//! If a callback needs to `await` a future, an async callback can be used. +//! These callbacks are functions that take a value and return a future that can +//! be awaited to process the value. The future returned is awaited to +//! completion before the next value is received from the channel. +//! +//! These callbacks can be configured using: +//! +//! - [`Receiver::on_receive_async`] +//! - [`BroadcastChannel::on_receive_async`] +//! - [`Builder::on_receive_async`] use std::collections::VecDeque; use std::fmt::{self, Debug}; use std::future::Future; use std::ops::ControlFlow; use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll, Waker}; +use std::sync::{mpsc, Arc}; +use std::task::{ready, Context, Poll, Waker}; +use builder::Builder; use parking_lot::{Condvar, Mutex, MutexGuard}; +use sealed::{AnyChannelCallback, AsyncCallbackFuture, CallbackKind, ChannelCallbackError}; use crate::reactive::{enqueue_task, BackgroundTask, ChannelTask}; use crate::value::CallbackDisconnected; -/// An error occurred while trying to send a value. +pub mod builder; + +/// Returns multi-producer, single-consumer channel with no limit to the number +/// of values enqueued. +#[must_use] +pub fn unbounded() -> (Sender, Receiver) +where + T: Send + 'static, +{ + Builder::new().finish() +} + +/// Returns multi-producer, single-consumer channel that limits queued values to +/// `capacity` items. +#[must_use] +pub fn bounded(capacity: usize) -> (Sender, Receiver) +where + T: Send + 'static, +{ + Builder::new().bounded(capacity).finish() +} + +/// Returns a [`Builder`] for a Cushy channel. +pub fn build() -> Builder { + Builder::default() +} + +mod sealed { + use std::future::Future; + use std::pin::Pin; + + pub enum CallbackKind { + Blocking(Box Result<(), super::CallbackDisconnected> + Send + 'static>), + NonBlocking(Box>), + } + + pub trait AnyChannelCallback: Send + 'static { + fn invoke(&mut self, value: T) -> Result<(), ChannelCallbackError>; + } + + pub enum ChannelCallbackError { + Async(AsyncCallbackFuture), + Disconnected, + } + + pub type AsyncCallbackFuture = + Pin>>>; +} + +/// An error occurred while trying to send a value to a channel. pub enum TrySendError { - /// The recipient was full. + /// The channel was full. Full(T), - /// The recipient is no longer reachable. + /// The channel no longer has any associated behaviors or receivers. Disconnected(T), } -/// A future that sends a message to a [`Channel`]. +/// A future that sends a value to a [channel](self). #[must_use = "Futures must be awaited to be executed"] pub struct SendAsync<'a, T> { value: Option, - channel: &'a Channel, + channel: &'a Sender, } impl Future for SendAsync<'_, T> where - T: Unpin + Clone + Send + 'static, + T: Unpin + Send + 'static, { - type Output = Option; + type Output = Result<(), T>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let Some(value) = self.value.take() else { - return Poll::Ready(None); + return Poll::Ready(Ok(())); }; - match self.channel.try_send_inner(value, |channel| { - let will_wake = channel - .wakers - .iter() - .any(|waker| waker.will_wake(cx.waker())); - if !will_wake { - channel.wakers.push(cx.waker().clone()); - } - ControlFlow::Break(()) - }) { - Ok(()) => Poll::Ready(None), - Err(TrySendError::Disconnected(value)) => Poll::Ready(Some(value)), + match self + .channel + .data + .try_send_inner(value, channel_id(&self.channel.data), |channel| { + let will_wake = channel + .wakers + .iter() + .any(|waker| waker.will_wake(cx.waker())); + if !will_wake { + channel.wakers.push(cx.waker().clone()); + } + ControlFlow::Break(()) + }) { + Ok(()) => Poll::Ready(Ok(())), + Err(TrySendError::Disconnected(value)) => Poll::Ready(Err(value)), Err(TrySendError::Full(value)) => { self.value = Some(value); Poll::Pending @@ -61,23 +200,24 @@ where pub(super) trait AnyChannel: Send + Sync + 'static { fn should_poll(&self) -> bool; fn poll(&self, futures: &mut Vec) -> bool; - fn disconnect_callback(&self); + fn disconnect(&self); } -impl AnyChannel for ChannelData +impl AnyChannel for Arc> where T: Send + 'static, + Behavior: CallbackBehavior, { fn should_poll(&self) -> bool { let channel = self.synced.lock(); - !channel.queue.is_empty() && channel.callback.is_some() + !channel.queue.is_empty() && channel.behavior.connected() } fn poll(&self, futures: &mut Vec) -> bool { let mut channel = self.synced.lock(); let Some(value) = channel.queue.pop_front() else { - return channel.instances > 0 && channel.callback.is_some(); + return channel.senders > 0 && channel.behavior.connected(); }; self.condvar.notify_all(); @@ -85,26 +225,27 @@ where waker.wake(); } - if let Some(callback) = &mut channel.callback { - match callback.invoke(value) { - Ok(()) => {} - Err(ChannelCallbackError::Async(future)) => { - futures.push(ChannelCallbackFuture { - future: Pin::from(future), - }); - } - Err(ChannelCallbackError::Disconnected) => { - channel.callback = None; - return false; - } + match channel.behavior.invoke(value, self) { + Ok(()) => {} + Err(ChannelCallbackError::Async(future)) => { + futures.push(ChannelCallbackFuture { future }); + } + Err(ChannelCallbackError::Disconnected) => { + return false; } } true } - fn disconnect_callback(&self) { - self.synced.lock().callback = None; + fn disconnect(&self) { + let mut data = self.synced.lock(); + data.behavior.disconnect(); + for waker in data.wakers.drain(..) { + waker.wake(); + } + drop(data); + self.condvar.notify_all(); } } @@ -112,239 +253,47 @@ pub(super) struct ChannelCallbackFuture { pub(super) future: Pin>>>, } -/// A reactive multi-sender, single-consumer (mpsc) channel that executes code -/// in the background when values are received. -/// -/// A [`Dynamic`](crate::value::Dynamic) is a container for a `T` that can be -/// reacted against. Due to this design, it is possible to not observe *every* -/// value that passes through the container. For some use cases (such as the -/// [Command Pattern][command]), it is important that every value is observed. -/// -/// This type ensures the associated code is executed for each value sent. -/// Additionally, unlike other channel types, this code is scheduled to be -/// executed by Cushy automatically instead of requiring additional threads or -/// an external async runtime. -/// -/// [command]: https://en.wikipedia.org/wiki/Command_pattern +/// A sender of values to a [channel](self). #[derive(Debug)] -pub struct Channel { - data: Arc>, +pub struct Sender { + data: Arc>>, } -impl Channel +impl Sender where T: Send + 'static, { - /// Returns a channel that executes `on_receive` for each value sent. - /// - /// The returned channel will never be considered full and will panic if a - /// large enough queue cannot be allocated. - #[must_use] - pub fn unbounded(mut on_receive: F) -> Self - where - F: FnMut(T) + Send + 'static, - { - Self::new( - None, - Box::new(move |value| { - on_receive(value); - Ok(()) - }), - ) - } - - /// Returns a channel that executes `on_receive` for each value sent. The - /// channel will be disconnected if the callback returns - /// `Err(CallbackDisconnected)`. - /// - /// The returned channel will never be considered full and will panic if a - /// large enough queue cannot be allocated. - #[must_use] - pub fn unbounded_try(mut on_receive: F) -> Self - where - F: FnMut(T) -> Result<(), CallbackDisconnected> + Send + 'static, - { - Self::new( - None, - Box::new(move |value| { - on_receive(value).map_err(|_| ChannelCallbackError::Disconnected) - }), - ) - } - - /// Returns a channel that executes the future returned from `on_receive` - /// for each value sent. - /// - /// The returned channel will never be considered full and will panic if a - /// large enough queue cannot be allocated. - #[must_use] - pub fn unbounded_async(mut on_receive: F) -> Self - where - F: FnMut(T) -> Fut + Send + 'static, - Fut: Future + Send + 'static, - { - Self::new( - None, - Box::new(move |value| { - let future = on_receive(value); - Err(ChannelCallbackError::Async(Box::new(async move { - future.await; - Ok(()) - }))) - }), - ) - } - - /// Returns a channel that executes the future returned from `on_receive` - /// for each value sent. The channel will be disconnected if the callback - /// returns `Err(CallbackDisconnected)`. - /// - /// The returned channel will never be considered full and will panic if a - /// large enough queue cannot be allocated. - #[must_use] - pub fn unbounded_async_try(mut on_receive: F) -> Self - where - F: FnMut(T) -> Fut + Send + 'static, - Fut: Future> + Send + 'static, - { - Self::new( - None, - Box::new(move |value| Err(ChannelCallbackError::Async(Box::new(on_receive(value))))), - ) - } - - /// Returns a bounded channel that executes `on_receive` for each value - /// sent. - /// - /// The returned channel will only allow `capacity` values to be queued at - /// any moment in time. Each `send` function documents what happens when the - /// channel is full. - #[must_use] - pub fn bounded(capacity: usize, mut on_receive: F) -> Self - where - F: FnMut(T) + Send + 'static, - { - Self::new( - Some(capacity), - Box::new(move |value| { - on_receive(value); - Ok(()) - }), - ) - } - - /// Returns a bounded channel that executes `on_receive` for each value - /// sent. The channel will be disconnected if the callback returns - /// `Err(CallbackDisconnected)`. - /// - /// The returned channel will only allow `capacity` values to be queued at - /// any moment in time. Each `send` function documents what happens when the - /// channel is full. - #[must_use] - pub fn bounded_try(capacity: usize, mut on_receive: F) -> Self - where - F: FnMut(T) -> Result<(), CallbackDisconnected> + Send + 'static, - { - Self::new( - Some(capacity), - Box::new(move |value| { - on_receive(value).map_err(|_| ChannelCallbackError::Disconnected) - }), - ) - } - - /// Returns a bounded channel that executes the future returned from - /// `on_receive` for each value sent. - /// - /// The returned channel will only allow `capacity` values to be queued at - /// any moment in time. Each `send` function documents what happens when the - /// channel is full. - #[must_use] - pub fn bounded_async(capacity: usize, mut on_receive: F) -> Self - where - F: FnMut(T) -> Fut + Send + 'static, - Fut: Future + Send + 'static, - { - Self::new( - Some(capacity), - Box::new(move |value| { - let future = on_receive(value); - Err(ChannelCallbackError::Async(Box::new(async move { - future.await; - Ok(()) - }))) - }), - ) - } - - /// Returns a bounded channel that executes the future returned from `on_receive` - /// for each value sent. The channel will be disconnected if the callback - /// returns `Err(CallbackDisconnected)`. - /// - /// The returned channel will only allow `capacity` values to be queued at - /// any moment in time. Each `send` function documents what happens when the - /// channel is full. - #[must_use] - pub fn bounded_async_try(capacity: usize, mut on_receive: F) -> Self - where - F: FnMut(T) -> Fut + Send + 'static, - Fut: Future> + Send + 'static, - { - Self::new( - Some(capacity), - Box::new(move |value| Err(ChannelCallbackError::Async(Box::new(on_receive(value))))), - ) - } - - fn new(limit: Option, callback: Box>) -> Self { - let (queue, limit) = match limit { - Some(limit) => (VecDeque::with_capacity(limit), limit), - None => (VecDeque::new(), usize::MAX), - }; - let this = Self { - data: Arc::new(ChannelData { - condvar: Condvar::new(), - synced: Mutex::new(SyncedChannelData { - queue, - limit, - instances: 1, - wakers: Vec::new(), - callback: Some(callback), - }), - }), - }; - enqueue_task(BackgroundTask::Channel(ChannelTask::Register { - id: this.id(), - data: this.data.clone(), - })); - this - } - /// Sends `value` to this channel. /// - /// Returns `Some(value)` if the channel is disconnected. - /// /// If the channel is full, this function will block the current thread - /// until space is made available. If another Channel's `on_receive` is - /// sending a value to a bounded channel, that `on_receive` should be async - /// and use [`send_async()`](Self::send_async) instead. - #[allow(clippy::must_use_candidate)] - pub fn send(&self, value: T) -> Option { - match self.try_send_inner(value, |channel| { - self.data.condvar.wait(channel); - ControlFlow::Continue(()) - }) { - Ok(()) => None, - Err(TrySendError::Disconnected(value) | TrySendError::Full(value)) => Some(value), + /// until space is made available. If one channel's `on_receive` is sending + /// a value to a bounded channel, that `on_receive` should be + /// `on_receive_async` instead and use [`send_async()`](Self::send_async). + /// + /// # Errors + /// + /// Returns `Err(value)` if the channel is disconnected. + pub fn send(&self, value: T) -> Result<(), T> { + match self + .data + .try_send_inner(value, channel_id(&self.data), |channel| { + self.data.condvar.wait(channel); + ControlFlow::Continue(()) + }) { + Ok(()) => Ok(()), + Err(TrySendError::Disconnected(value) | TrySendError::Full(value)) => Err(value), } } /// Sends `value` to this channel asynchronously. /// - /// The future returns `Some(value)` if the channel is disconnected. - /// /// If the channel is full, this future will wait until space is made /// available before sending. + /// + /// # Errors + /// + /// The returned future will return `Err(value)` if the channel is + /// disconnected. pub fn send_async(&self, value: T) -> SendAsync<'_, T> { SendAsync { value: Some(value), @@ -362,156 +311,1184 @@ where /// - When the channel is full, [`TrySendError::Full`] will /// be returned. pub fn try_send(&self, value: T) -> Result<(), TrySendError> { - self.try_send_inner(value, |_| ControlFlow::Break(())) + self.data + .try_send_inner(value, channel_id(&self.data), |_| ControlFlow::Break(())) } +} - fn try_send_inner( - &self, - value: T, - mut when_full: impl FnMut(&mut MutexGuard<'_, SyncedChannelData>) -> ControlFlow<()>, - ) -> Result<(), TrySendError> { +impl Clone for Sender { + fn clone(&self) -> Self { let mut channel = self.data.synced.lock(); - while channel.callback.is_some() { - if channel.queue.len() >= channel.limit { - match when_full(&mut channel) { - ControlFlow::Continue(()) => continue, - ControlFlow::Break(()) => return Err(TrySendError::Full(value)), + channel.senders += 1; + Self { + data: self.data.clone(), + } + } +} + +impl Drop for Sender { + fn drop(&mut self) { + let mut channel = self.data.synced.lock(); + channel.senders -= 1; + + if channel.senders == 0 { + match &channel.behavior { + SingleCallback::Receiver => { + drop(channel); + self.data.condvar.notify_all(); } + SingleCallback::Callback(_) => { + drop(channel); + enqueue_task(BackgroundTask::Channel(ChannelTask::Unregister( + channel_id(&self.data), + ))); + } + SingleCallback::Disconnected => {} } - let should_notify = channel.queue.is_empty(); - channel.queue.push_back(value); - drop(channel); + } + } +} - if should_notify { - enqueue_task(BackgroundTask::Channel(ChannelTask::Notify { - id: self.id(), - })); - } +enum SingleCallback { + Receiver, + Callback(Box>), + Disconnected, +} - return Ok(()); +impl CallbackBehavior for SingleCallback +where + T: Send + 'static, +{ + fn connected(&self) -> bool { + !matches!(self, Self::Disconnected) + } + + fn disconnect(&mut self) { + *self = Self::Disconnected; + } + + fn invoke( + &mut self, + value: T, + _channel: &Arc>, + ) -> Result<(), ChannelCallbackError> { + let cb = match self { + SingleCallback::Receiver => unreachable!("callback installed without callback"), + SingleCallback::Callback(cb) => cb, + SingleCallback::Disconnected => return Err(ChannelCallbackError::Disconnected), + }; + + match cb.invoke(value) { + Err(ChannelCallbackError::Disconnected) => { + *self = Self::Disconnected; + Err(ChannelCallbackError::Disconnected) + } + other => other, } - Err(TrySendError::Disconnected(value)) } } -impl Channel { - fn id(&self) -> usize { - Arc::as_ptr(&self.data) as usize +impl Debug for SingleCallback { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SingleCallback::Receiver => f.write_str("0 callbacks"), + SingleCallback::Callback(_) => f.write_str("1 callback"), + SingleCallback::Disconnected => f.write_str("disconnected"), + } } } -impl Clone for Channel { - fn clone(&self) -> Self { - let mut channel = self.data.synced.lock(); - channel.instances += 1; - Self { - data: self.data.clone(), +enum BroadcastCallback { + Blocking { + sender: mpsc::SyncSender<(T, Waker)>, + result: mpsc::Receiver<()>, + }, + NonBlocking(Box>), +} + +impl BroadcastCallback { + fn spawn_blocking( + mut cb: Box Result<(), super::CallbackDisconnected> + Send + 'static>, + ) -> Self + where + T: Send + 'static, + { + let (value_sender, value_receiver) = mpsc::sync_channel::<(T, Waker)>(1); + let (result_sender, result_receiver) = mpsc::sync_channel(1); + std::thread::spawn(move || { + while let Ok((value, waker)) = value_receiver.recv() { + if let Ok(()) = cb(value) { + if result_sender.send(()).is_err() { + return; + } + waker.wake(); + } else { + drop(result_sender); + waker.wake(); + return; + } + } + }); + Self::Blocking { + sender: value_sender, + result: result_receiver, } } } -impl Drop for Channel { - fn drop(&mut self) { - let mut channel = self.data.synced.lock(); - channel.instances -= 1; +struct MultipleCallbacks(Vec>); - if channel.instances == 0 { - drop(channel); - enqueue_task(BackgroundTask::Channel(ChannelTask::Unregister(self.id()))); +impl CallbackBehavior for MultipleCallbacks +where + T: Unpin + Clone + Send + 'static, +{ + fn connected(&self) -> bool { + !self.0.is_empty() + } + + fn disconnect(&mut self) { + self.0.clear(); + } + + fn invoke( + &mut self, + value: T, + channel: &Arc>, + ) -> Result<(), ChannelCallbackError> { + let mut sent_one = false; + + let mut i = 0; + let mut value = TakeN::new(value, self.0.len()); + while i < self.0.len() { + match &mut self.0[i] { + BroadcastCallback::Blocking { .. } => { + return Err(ChannelCallbackError::Async(Box::pin(BroadcastSend { + value, + sent_one, + data: channel.clone(), + current_recipient_future: None, + current_is_blocking: false, + next_recipient: i, + }))) + } + BroadcastCallback::NonBlocking(cb) => { + match cb.invoke(value.next().expect("enough value clones")) { + Ok(()) => { + sent_one = true; + } + Err(ChannelCallbackError::Disconnected) => { + self.0.remove(i); + continue; + } + Err(ChannelCallbackError::Async(future)) => { + return Err(ChannelCallbackError::Async(Box::pin(BroadcastSend { + value, + sent_one, + data: channel.clone(), + current_recipient_future: Some(future), + current_is_blocking: false, + next_recipient: i + 1, + }))) + } + } + } + } + + i += 1; + } + + if sent_one { + Ok(()) + } else { + Err(ChannelCallbackError::Disconnected) } } } -#[derive(Debug)] -struct ChannelData { - condvar: Condvar, - synced: Mutex>, +impl Debug for MultipleCallbacks { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.0.len() == 1 { + f.write_str("1 callback") + } else { + write!(f, "{} callbacks", self.0.len()) + } + } } -struct SyncedChannelData { - queue: VecDeque, - limit: usize, - instances: usize, - wakers: Vec, +struct TakeN { + value: Option, + remaining: usize, +} - callback: Option>>, +impl TakeN { + fn new(value: T, count: usize) -> Self { + Self { + value: Some(value), + remaining: count, + } + } } -impl Debug for SyncedChannelData +impl Iterator for TakeN where - T: Debug, + T: Clone, { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("SyncedChannelData") - .field("queue", &self.queue) - .field("limit", &self.limit) - .field("instances", &self.instances) - .field("wakers", &self.wakers) - .field( - "callback", - &if self.callback.is_some() { - "(Connected)" - } else { - "(Disconnected)" - }, - ) - .finish() + type Item = T; + + fn next(&mut self) -> Option { + self.remaining = self.remaining.saturating_sub(1); + if self.remaining > 0 { + self.value.clone() + } else { + self.value.take() + } } } -trait AnyChannelCallback: Send + 'static { - fn invoke(&mut self, value: T) -> Result<(), ChannelCallbackError>; +struct BroadcastSend { + sent_one: bool, + value: TakeN, + next_recipient: usize, + data: Arc>>, + current_recipient_future: Option, + current_is_blocking: bool, } -impl AnyChannelCallback for F +impl BroadcastSend where - F: FnMut(T) -> Result<(), ChannelCallbackError> + Send + 'static, + T: Unpin + Clone + Send + 'static, { - fn invoke(&mut self, value: T) -> Result<(), ChannelCallbackError> { - self(value) + fn poll_tasks(&mut self, cx: &mut Context<'_>) -> Poll<()> { + if let Some(future) = &mut self.current_recipient_future { + match ready!(future.as_mut().poll(cx)) { + Ok(()) => { + self.current_recipient_future = None; + self.sent_one = true; + } + Err(CallbackDisconnected) => { + self.current_recipient_future = None; + } + } + } else if self.current_is_blocking { + let mut data = self.data.synced.lock(); + let BroadcastCallback::Blocking { result, .. } = &data.behavior.0[self.next_recipient] + else { + unreachable!("valid state"); + }; + match result.try_recv() { + Ok(()) => { + self.next_recipient += 1; + } + Err(mpsc::TryRecvError::Empty) => return Poll::Pending, + Err(mpsc::TryRecvError::Disconnected) => { + data.behavior.0.remove(self.next_recipient); + } + } + self.current_is_blocking = false; + } + Poll::Ready(()) } } -enum ChannelCallbackError { - Async(Box>>), - Disconnected, -} +impl Future for BroadcastSend +where + T: Unpin + Clone + Send + 'static, +{ + type Output = Result<(), CallbackDisconnected>; -#[test] -fn channel_basics() { - use crate::value::{Destination, Dynamic, Source}; - let result = Dynamic::new(0); - let result_reader = result.create_reader(); - let channel = Channel::::unbounded(move |value| result.set(dbg!(value))); - assert!(!result_reader.has_updated()); - channel.send(1); - result_reader.block_until_updated(); - assert_eq!(result_reader.get(), 1); -} + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = &mut *self; + ready!(this.poll_tasks(cx)); -#[test] -fn async_channels() { - use crate::value::{Destination, Dynamic, Source}; + let mut data_mutex = this.data.synced.lock(); + loop { + let data = &mut *data_mutex; + if let Some(cb) = data.behavior.0.get_mut(this.next_recipient) { + match cb { + BroadcastCallback::Blocking { sender, .. } => { + if let Ok(()) = sender.send(( + this.value.next().expect("enough value clones"), + cx.waker().clone(), + )) { + this.current_is_blocking = true; + drop(data_mutex); - let result = Dynamic::new(0); - let result_reader = result.create_reader(); - let channel2 = Channel::::unbounded_async(move |value| { - let result = result.clone(); - async move { - result.set(dbg!(value)); - } - }); - let channel1 = Channel::::unbounded_async(move |value| { - let channel2 = channel2.clone(); - async move { - channel2.send(dbg!(value)); + ready!(this.poll_tasks(cx)); + + data_mutex = this.data.synced.lock(); + continue; + } + + data.behavior.0.remove(this.next_recipient); + continue; + } + BroadcastCallback::NonBlocking(cb) => { + match cb.invoke(this.value.next().expect("enough value clones")) { + Ok(()) => { + this.sent_one = true; + } + Err(ChannelCallbackError::Disconnected) => { + data.behavior.0.remove(this.next_recipient); + continue; + } + Err(ChannelCallbackError::Async(future)) => { + this.current_recipient_future = Some(future); + drop(data_mutex); + + ready!(this.poll_tasks(cx)); + + data_mutex = this.data.synced.lock(); + } + } + } + } + + this.next_recipient += 1; + } else if this.sent_one { + return Poll::Ready(Ok(())); + } else { + for waker in data.wakers.drain(..) { + waker.wake(); + } + drop(data_mutex); + this.data.condvar.notify_all(); + return Poll::Ready(Err(CallbackDisconnected)); + } + } + } +} + +trait CallbackBehavior: Sized + Send + 'static { + fn connected(&self) -> bool; + fn disconnect(&mut self); + fn invoke( + &mut self, + value: T, + channel: &Arc>, + ) -> Result<(), ChannelCallbackError>; +} + +#[derive(Debug)] +struct ChannelData { + condvar: Condvar, + synced: Mutex>, +} + +impl ChannelData +where + T: Send + 'static, + Behavior: CallbackBehavior, +{ + fn new( + limit: Option, + behavior: Behavior, + senders: usize, + receivers: usize, + ) -> Arc> { + let (queue, limit) = match limit { + Some(limit) => (VecDeque::with_capacity(limit), limit), + None => (VecDeque::new(), usize::MAX), + }; + let this = Arc::new(ChannelData { + condvar: Condvar::new(), + synced: Mutex::new(SyncedChannelData { + queue, + limit, + senders, + receivers, + wakers: Vec::new(), + behavior, + }), + }); + enqueue_task(BackgroundTask::Channel(ChannelTask::Register { + id: channel_id(&this), + data: Arc::new(this.clone()), + })); + this + } + + fn try_send_inner( + &self, + value: T, + id: usize, + mut when_full: impl FnMut( + &mut MutexGuard<'_, SyncedChannelData>, + ) -> ControlFlow<()>, + ) -> Result<(), TrySendError> { + let mut channel = self.synced.lock(); + while channel.receivers > 0 || channel.behavior.connected() { + if channel.queue.len() >= channel.limit { + match when_full(&mut channel) { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(()) => return Err(TrySendError::Full(value)), + } + } + let has_receiver = channel.receivers > 0; + let should_notify = !has_receiver && channel.queue.is_empty(); + channel.queue.push_back(value); + drop(channel); + + if should_notify { + enqueue_task(BackgroundTask::Channel(ChannelTask::Notify { id })); + } else if has_receiver { + self.condvar.notify_all(); + } + + return Ok(()); + } + Err(TrySendError::Disconnected(value)) + } +} + +struct SyncedChannelData { + queue: VecDeque, + limit: usize, + senders: usize, + receivers: usize, + wakers: Vec, + + behavior: Behavior, +} + +impl Debug for SyncedChannelData +where + T: Debug, + Behavior: Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SyncedChannelData") + .field("queue", &self.queue) + .field("limit", &self.limit) + .field("senders", &self.senders) + .field("receiers", &self.receivers) + .field("wakers", &self.wakers) + .field("behavior", &self.behavior) + .finish() + } +} + +/// A channel that broadcasts values received to one or more callbacks. +/// +/// This type represents both a sender and a receiver in terms of determining +/// whether a channel is "connected". This is because at any time additional +/// callbacks can be associated through this type while also allowing values to +/// be sent to already-installed callbacks. +/// +/// Because of this ability to attach future callbacks, a broadcast channel can +/// be created with no associated callbacks. When a value is sent to a channel +/// that has a [`BroadcastChannel`] connected to it, the value will be queued +/// even if no callbacks are currently associated. To prevent this, use +/// [`create_broadcaster()`](Self::create_broadcaster)/[`into_broadcaster()`](Self::into_broadcaster) +/// to create a [`Broadcaster`] for this channel and drop all +/// [`BroadcastChannel`] instances after callbacks have been associated. +pub struct BroadcastChannel { + data: Arc>>, +} + +impl BroadcastChannel +where + T: Unpin + Clone + Send + 'static, +{ + /// Returns broadcast channel with no limit to the number of values + /// enqueued. + #[must_use] + pub fn unbounded() -> Self { + Builder::new().broadcasting().finish() + } + + /// Returns broadcast channel that limits queued values to `capacity` items. + #[must_use] + pub fn bounded(capacity: usize) -> Self { + Builder::new().broadcasting().bounded(capacity).finish() + } + + /// Returns a builder for a broadcast channel. + pub fn build() -> Builder> { + Builder::new().broadcasting() + } + + /// Sends `value` to this channel. + /// + /// If the channel is full, this function will block the current thread + /// until space is made available. If one channel's `on_receive` is sending + /// a value to a bounded channel, that `on_receive` should be + /// `on_receive_async` instead and use [`send_async()`](Self::send_async). + /// + /// # Errors + /// + /// Returns `Err(value)` if the channel is disconnected. + pub fn send(&self, value: T) -> Result<(), T> { + match self + .data + .try_send_inner(value, channel_id(&self.data), |channel| { + self.data.condvar.wait(channel); + ControlFlow::Continue(()) + }) { + Ok(()) => Ok(()), + Err(TrySendError::Disconnected(value) | TrySendError::Full(value)) => Err(value), + } + } + + /// Sends `value` to this channel asynchronously. + /// + /// If the channel is full, this future will wait until space is made + /// available before sending. + /// + /// # Errors + /// + /// The returned future will return `Err(value)` if the channel is + /// disconnected. + pub fn send_async(&self, value: T) -> BroadcastAsync<'_, T> { + BroadcastAsync { + value: Some(value), + channel: &self.data, + } + } + + /// Tries to send `value` to this channel. Returns an error if unable to + /// send. + /// + /// # Errors + /// + /// - When the channel is disconnected, [`TrySendError::Disconnected`] will + /// be returned. + /// - When the channel is full, [`TrySendError::Full`] will + /// be returned. + pub fn try_send(&self, value: T) -> Result<(), TrySendError> { + self.data + .try_send_inner(value, channel_id(&self.data), |_| ControlFlow::Break(())) + } + + /// Returns a [`Broadcaster`] that sends to this channel. + #[must_use] + pub fn create_broadcaster(&self) -> Broadcaster { + let mut data = self.data.synced.lock(); + data.senders += 1; + Broadcaster { + data: self.data.clone(), + } + } + + /// Returns this instance as a [`Broadcaster`] that sends to this channel. + #[must_use] + pub fn into_broadcaster(self) -> Broadcaster { + self.create_broadcaster() + } + + /// Invokes `on_receive` each time a value is sent to this channel. + /// + /// This function assumes `on_receive` may block while waiting on another + /// thread, another process, another callback, a network request, a locking + /// primitive, or any other number of ways that could impact other callbacks + /// from executing. + pub fn on_receive(&self, mut on_receive: Map) + where + Map: FnMut(T) + Send + 'static, + { + self.on_receive_try(move |value| { + on_receive(value); + Ok(()) + }); + } + + /// Invokes `on_receive` each time a value is sent to this channel. Once an + /// error is returned, this callback will be removed from the channel. + /// + /// This function assumes `on_receive` may block while waiting on another + /// thread, another process, another callback, a network request, a locking + /// primitive, or any other number of ways that could impact other callbacks + /// from executing. + /// + /// Once the last callback associated with a channel is removed, [`Sender`]s + /// will begin returning disconnected errors. + pub fn on_receive_try(&self, on_receive: Map) + where + Map: FnMut(T) -> Result<(), CallbackDisconnected> + Send + 'static, + { + self.on_receive_inner(CallbackKind::Blocking(Box::new(on_receive))); + } + + /// Invokes `on_receive` each time a value is sent to this channel. + /// + /// This function assumes `on_receive` will not block while waiting on + /// another thread, another process, another callback, a network request, a + /// locking primitive, or any other number of ways that could impact other + /// callbacks from executing in a shared environment. + pub fn on_receive_nonblocking(&self, mut on_receive: Map) + where + Map: FnMut(T) + Send + 'static, + { + self.on_receive_nonblocking_try(move |value| { + on_receive(value); + Ok(()) + }); + } + + /// Invokes `on_receive` each time a value is sent to this channel. Once an + /// error is returned, this callback will be removed from the channel. + /// + /// This function assumes `on_receive` will not block while waiting on + /// another thread, another process, another callback, a network request, a + /// locking primitive, or any other number of ways that could impact other + /// callbacks from executing in a shared environment. + /// + /// Once the last callback associated with a channel is removed, [`Sender`]s + /// will begin returning disconnected errors. + pub fn on_receive_nonblocking_try(&self, mut on_receive: Map) + where + Map: FnMut(T) -> Result<(), CallbackDisconnected> + Send + 'static, + { + self.on_receive_inner(CallbackKind::NonBlocking(Box::new(move |value| { + on_receive(value).map_err(|CallbackDisconnected| ChannelCallbackError::Disconnected) + }))); + } + + /// Invokes `on_receive` each time a value is sent to this channel. + pub fn on_receive_async(&self, mut on_receive: Map) + where + Map: FnMut(T) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + { + self.on_receive_async_try(move |value| { + let future = on_receive(value); + async move { + future.await; + Ok(()) + } + }); + } + + /// Invokes `on_receive` each time a value is sent to this channel. The + /// returned future will then be awaited. Once an error is returned, this + /// callback will be removed from the channel. + /// + /// Once the last callback associated with a channel is removed, [`Sender`]s + /// will begin returning disconnected errors. + pub fn on_receive_async_try(&self, mut on_receive: Map) + where + Map: FnMut(T) -> Fut + Send + 'static, + Fut: Future> + Send + 'static, + { + self.on_receive_inner(CallbackKind::NonBlocking(Box::new(move |value| { + let future = on_receive(value); + Err(ChannelCallbackError::Async(Box::pin(future))) + }))); + } + + fn on_receive_inner(&self, cb: CallbackKind) { + let mut data = self.data.synced.lock(); + let should_register = data.behavior.0.is_empty(); + match cb { + CallbackKind::Blocking(cb) => { + data.behavior.0.push(BroadcastCallback::spawn_blocking(cb)); + } + CallbackKind::NonBlocking(cb) => { + data.behavior.0.push(BroadcastCallback::NonBlocking(cb)); + } + } + if should_register { + drop(data); + enqueue_task(BackgroundTask::Channel(ChannelTask::Register { + id: channel_id(&self.data), + data: Arc::new(self.data.clone()), + })); + } + } +} + +impl Clone for BroadcastChannel { + fn clone(&self) -> Self { + let mut data = self.data.synced.lock(); + data.senders += 1; + data.receivers += 1; + drop(data); + Self { + data: self.data.clone(), + } + } +} + +impl Drop for BroadcastChannel { + fn drop(&mut self) { + let mut data = self.data.synced.lock(); + data.senders -= 1; + data.receivers -= 1; + + let notify_disconnected = data.senders == 0 || data.behavior.0.is_empty(); + if notify_disconnected { + for waker in data.wakers.drain(..) { + waker.wake(); + } + } + drop(data); + if notify_disconnected { + self.data.condvar.notify_all(); + enqueue_task(BackgroundTask::Channel(ChannelTask::Unregister( + channel_id(&self.data), + ))); + } + } +} + +/// Sends values to a [`BroadcastChannel`]. +pub struct Broadcaster { + data: Arc>>, +} + +impl Broadcaster +where + T: Unpin + Clone + Send + 'static, +{ + /// Sends `value` to this channel. + /// + /// If the channel is full, this function will block the current thread + /// until space is made available. If one channel's `on_receive` is sending + /// a value to a bounded channel, that `on_receive` should be + /// `on_receive_async` instead and use [`send_async()`](Self::send_async). + /// + /// # Errors + /// + /// Returns `Err(value)` if the channel is disconnected. + pub fn send(&self, value: T) -> Result<(), T> { + match self + .data + .try_send_inner(value, channel_id(&self.data), |channel| { + self.data.condvar.wait(channel); + ControlFlow::Continue(()) + }) { + Ok(()) => Ok(()), + Err(TrySendError::Disconnected(value) | TrySendError::Full(value)) => Err(value), + } + } + + /// Sends `value` to this channel asynchronously. + /// + /// If the channel is full, this future will wait until space is made + /// available before sending. + /// + /// # Errors + /// + /// The returned future will return `Err(value)` if the channel is + /// disconnected. + pub fn send_async(&self, value: T) -> BroadcastAsync<'_, T> { + BroadcastAsync { + value: Some(value), + channel: &self.data, + } + } + + /// Tries to send `value` to this channel. Returns an error if unable to + /// send. + /// + /// # Errors + /// + /// - When the channel is disconnected, [`TrySendError::Disconnected`] will + /// be returned. + /// - When the channel is full, [`TrySendError::Full`] will + /// be returned. + pub fn try_send(&self, value: T) -> Result<(), TrySendError> { + self.data + .try_send_inner(value, channel_id(&self.data), |_| ControlFlow::Break(())) + } +} + +impl Clone for Broadcaster { + fn clone(&self) -> Self { + let mut data = self.data.synced.lock(); + data.senders += 1; + drop(data); + Self { + data: self.data.clone(), + } + } +} + +impl Drop for Broadcaster { + fn drop(&mut self) { + let mut data = self.data.synced.lock(); + data.senders -= 1; + + let notify_disconnected = data.senders == 0; + if notify_disconnected { + for waker in data.wakers.drain(..) { + waker.wake(); + } + } + drop(data); + if notify_disconnected { + self.data.condvar.notify_all(); + enqueue_task(BackgroundTask::Channel(ChannelTask::Unregister( + channel_id(&self.data), + ))); + } + } +} + +/// A future that broadcasts a value to a [`BroadcastChannel`]. +#[must_use = "Futures must be awaited to be executed"] +pub struct BroadcastAsync<'a, T> { + value: Option, + channel: &'a Arc>>, +} + +impl Future for BroadcastAsync<'_, T> +where + T: Unpin + Clone + Send + 'static, +{ + type Output = Option; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let Some(value) = self.value.take() else { + return Poll::Ready(None); + }; + match self + .channel + .try_send_inner(value, channel_id(self.channel), |channel| { + let will_wake = channel + .wakers + .iter() + .any(|waker| waker.will_wake(cx.waker())); + if !will_wake { + channel.wakers.push(cx.waker().clone()); + } + ControlFlow::Break(()) + }) { + Ok(()) => Poll::Ready(None), + Err(TrySendError::Disconnected(value)) => Poll::Ready(Some(value)), + Err(TrySendError::Full(value)) => { + self.value = Some(value); + Poll::Pending + } + } + } +} + +impl AnyChannelCallback for F +where + F: FnMut(T) -> Result<(), ChannelCallbackError> + Send + 'static, +{ + fn invoke(&mut self, value: T) -> Result<(), ChannelCallbackError> { + self(value) + } +} + +fn channel_id(data: &Arc>) -> usize { + Arc::as_ptr(data) as usize +} + +/// A receiver for values sent by a [`Sender`]. +pub struct Receiver { + data: Arc>>, +} + +impl Receiver +where + T: Send + 'static, +{ + /// Returns the next value, blocking the current thread until one is + /// available. + /// + /// Returns `None` if there are no [`Sender`]s still connected to this + /// channel. + #[must_use] + pub fn receive(&self) -> Option { + self.try_receive_inner(|guard| { + self.data.condvar.wait(guard); + ControlFlow::Continue(()) + }) + .ok() + } + + /// Returns the next value if possible, otherwise returning an error + /// describing why a value was unable to be received. + /// + /// This function will not block the current thread. + pub fn try_receive(&self) -> Result { + self.try_receive_inner(|_guard| ControlFlow::Break(())) + } + + fn try_receive_inner( + &self, + mut when_empty: impl FnMut( + &mut MutexGuard<'_, SyncedChannelData>>, + ) -> ControlFlow<()>, + ) -> Result { + let mut data = self.data.synced.lock(); + loop { + if let Some(value) = data.queue.pop_front() { + for waker in data.wakers.drain(..) { + waker.wake(); + } + drop(data); + self.data.condvar.notify_all(); + return Ok(value); + } + + if data.senders == 0 { + return Err(TryReceiveError::Disconnected); + } + + if when_empty(&mut data).is_break() { + return Err(TryReceiveError::Empty); + } + } + } + + /// Invokes `on_receive` each time a value is sent to this channel. + /// + /// This function assumes `on_receive` may block while waiting on another + /// thread, another process, another callback, a network request, a locking + /// primitive, or any other number of ways that could impact other callbacks + /// from executing. + pub fn on_receive(self, mut on_receive: Map) + where + Map: FnMut(T) + Send + 'static, + { + self.on_receive_try(move |value| { + on_receive(value); + Ok(()) + }); + } + + /// Invokes `on_receive` each time a value is sent to this channel. Once an + /// error is returned, this callback will be removed from the channel. + /// + /// This function assumes `on_receive` may block while waiting on another + /// thread, another process, another callback, a network request, a locking + /// primitive, or any other number of ways that could impact other callbacks + /// from executing. + /// + /// Once the last callback associated with a channel is removed, [`Sender`]s + /// will begin returning disconnected errors. + pub fn on_receive_try(self, on_receive: Map) + where + Map: FnMut(T) -> Result<(), CallbackDisconnected> + Send + 'static, + { + self.on_receive_inner(CallbackKind::Blocking(Box::new(on_receive))); + } + + /// Invokes `on_receive` each time a value is sent to this channel. + /// + /// This function assumes `on_receive` will not block while waiting on + /// another thread, another process, another callback, a network request, a + /// locking primitive, or any other number of ways that could impact other + /// callbacks from executing in a shared environment. + pub fn on_receive_nonblocking(self, mut on_receive: Map) + where + Map: FnMut(T) + Send + 'static, + { + self.on_receive_nonblocking_try(move |value| { + on_receive(value); + Ok(()) + }); + } + + /// Invokes `on_receive` each time a value is sent to this channel. Once an + /// error is returned, this callback will be removed from the channel. + /// + /// This function assumes `on_receive` will not block while waiting on + /// another thread, another process, another callback, a network request, a + /// locking primitive, or any other number of ways that could impact other + /// callbacks from executing in a shared environment. + /// + /// Once the last callback associated with a channel is removed, [`Sender`]s + /// will begin returning disconnected errors. + pub fn on_receive_nonblocking_try(self, mut on_receive: Map) + where + Map: FnMut(T) -> Result<(), CallbackDisconnected> + Send + 'static, + { + self.on_receive_inner(CallbackKind::NonBlocking(Box::new(move |value| { + on_receive(value).map_err(|CallbackDisconnected| ChannelCallbackError::Disconnected) + }))); + } + + /// Invokes `on_receive` each time a value is sent to this channel. + pub fn on_receive_async(self, mut on_receive: Map) + where + Map: FnMut(T) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + { + self.on_receive_async_try(move |value| { + let future = on_receive(value); + async move { + future.await; + Ok(()) + } + }); + } + + /// Invokes `on_receive` each time a value is sent to this channel. The + /// returned future will then be awaited. Once an error is returned, this + /// callback will be removed from the channel. + /// + /// Once the last callback associated with a channel is removed, [`Sender`]s + /// will begin returning disconnected errors. + pub fn on_receive_async_try(self, mut on_receive: Map) + where + Map: FnMut(T) -> Fut + Send + 'static, + Fut: Future> + Send + 'static, + { + self.on_receive_inner(CallbackKind::NonBlocking(Box::new(move |value| { + let future = on_receive(value); + Err(ChannelCallbackError::Async(Box::pin(future))) + }))); + } + + fn on_receive_inner(self, cb: CallbackKind) { + match cb { + CallbackKind::Blocking(fn_mut) => { + self.spawn_thread(fn_mut); + } + CallbackKind::NonBlocking(cb) => { + let mut data = self.data.synced.lock(); + data.behavior = SingleCallback::Callback(cb); + drop(data); + enqueue_task(BackgroundTask::Channel(ChannelTask::Register { + id: channel_id(&self.data), + data: Arc::new(self.data.clone()), + })); + } + } + } + + fn spawn_thread( + self, + mut cb: Box Result<(), super::CallbackDisconnected> + Send + 'static>, + ) { + std::thread::spawn(move || { + while let Some(value) = self.receive() { + if let Err(CallbackDisconnected) = cb(value) { + return; + } + } + }); + } +} + +impl Future for &Receiver +where + T: Unpin + Send + 'static, +{ + type Output = Option; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.try_receive_inner(|guard| { + let will_wake = guard.wakers.iter().any(|w| w.will_wake(cx.waker())); + if !will_wake { + guard.wakers.push(cx.waker().clone()); + } + ControlFlow::Break(()) + }) { + Ok(value) => Poll::Ready(Some(value)), + Err(TryReceiveError::Disconnected) => Poll::Ready(None), + Err(TryReceiveError::Empty) => Poll::Pending, + } + } +} + +impl Drop for Receiver { + fn drop(&mut self) { + let mut data = self.data.synced.lock(); + data.receivers -= 1; + if matches!(data.behavior, SingleCallback::Receiver) { + data.behavior = SingleCallback::Disconnected; + } + for waker in data.wakers.drain(..) { + waker.wake(); + } + drop(data); + self.data.condvar.notify_all(); + } +} + +/// An error trying to receive a value from a channel. +pub enum TryReceiveError { + /// The channel was empty. + Empty, + /// The channel has no senders connected. + Disconnected, +} + +#[test] +fn channel_basics() { + let (result_sender, result_receiver) = unbounded(); + + let sender = Builder::new() + .on_receive_nonblocking(move |value| result_sender.send(dbg!(value)).unwrap()) + .finish(); + sender.send(1).unwrap(); + + assert_eq!(result_receiver.receive().unwrap(), 1); + drop(sender); + assert_eq!(result_receiver.receive(), None); +} + +#[test] +fn send_then_spawn() { + let (result_sender, result_receiver) = unbounded(); + + let (sender, receiver) = Builder::new().finish(); + sender.send(1).unwrap(); + receiver.on_receive_nonblocking(move |value| result_sender.send(dbg!(value)).unwrap()); + + assert_eq!(result_receiver.receive().unwrap(), 1); + drop(sender); + assert_eq!(result_receiver.receive(), None); +} + +#[test] +fn disconnected_send() { + let (sender, receiver) = Builder::new().finish(); + // Sending is allowed while a receiver could theoretically receive it. + sender.send(1).unwrap(); + drop(receiver); + assert_eq!(sender.send(2), Err(2)); +} + +#[test] +fn broadcast_basic() { + let (result_sender, result_receiver) = unbounded(); + + let channel = Builder::new() + .broadcasting() + .on_receive_nonblocking({ + let result_sender = result_sender.clone(); + move |value| { + result_sender.send(value).unwrap(); + } + }) + .on_receive_nonblocking({ + move |value| { + result_sender.send(value).unwrap(); + } + }) + .finish(); + channel.send(1).unwrap(); + + assert_eq!(result_receiver.receive(), Some(1)); + assert_eq!(result_receiver.receive(), Some(1)); + drop(channel); + assert_eq!(result_receiver.receive(), None); +} + +#[test] +fn async_channels() { + let (a_sender, a_receiver) = bounded(1); + let (b_sender, b_receiver) = bounded(1); + + a_receiver.on_receive_async(move |value| { + let b_sender = b_sender.clone(); + async move { + for i in 0..value { + b_sender.send_async(dbg!(i)).await.unwrap(); + } } }); - assert!(!result_reader.has_updated()); - channel1.send(1); - result_reader.block_until_updated(); - assert_eq!(result_reader.get(), 1); - channel1.send(2); - result_reader.block_until_updated(); - assert_eq!(result_reader.get(), 2); + a_sender.send(5).unwrap(); + for i in 0..5 { + println!("Reading {i}"); + assert_eq!(b_receiver.receive(), Some(i)); + } + drop(a_sender); + assert_eq!(b_receiver.receive(), None); } diff --git a/src/reactive/channel/builder.rs b/src/reactive/channel/builder.rs new file mode 100644 index 000000000..785055d33 --- /dev/null +++ b/src/reactive/channel/builder.rs @@ -0,0 +1,341 @@ +//! Builder types for Cushy [`channel`](super)qs. +use std::future::Future; +use std::marker::PhantomData; + +use super::sealed::{CallbackKind, ChannelCallbackError}; +use super::{ + BroadcastCallback, BroadcastChannel, ChannelData, MultipleCallbacks, Receiver, Sender, +}; +use crate::value::CallbackDisconnected; + +/// Builds a Cushy channel. +/// +/// This type can be used to create all types of channels supported by Cushy. +/// See the [`channel`](self) module documentation for an overview of the +/// channels provided. +#[must_use] +pub struct Builder { + mode: Mode, + ty: PhantomData, + bound: Option, +} + +impl Default for Builder { + fn default() -> Self { + Self::new() + } +} + +impl Builder { + /// Returns a builder for a Cushy channel. + /// + /// The default builder will create an unbounded, Multi-Producer, + /// Single-Consumer channel. See the [`channel`](self) module documentation + /// for an overview of the channels provided. + pub const fn new() -> Self { + Self { + mode: SingleConsumer { _private: () }, + ty: PhantomData, + bound: None, + } + } +} + +impl Builder +where + T: Send + 'static, + Mode: ChannelMode + sealed::ChannelMode>::Next>, +{ + /// Invokes `on_receive` each time a value is sent to this channel. + /// + /// This function assumes `on_receive` may block while waiting on another + /// thread, another process, another callback, a network request, a locking + /// primitive, or any other number of ways that could impact other callbacks + /// from executing. + pub fn on_receive(self, mut on_receive: Map) -> Builder>::Next> + where + Map: FnMut(T) + Send + 'static, + { + self.on_receive_try(move |value| { + on_receive(value); + Ok(()) + }) + } + + /// Invokes `on_receive` each time a value is sent to this channel. Once an + /// error is returned, this callback will be removed from the channel. + /// + /// This function assumes `on_receive` may block while waiting on another + /// thread, another process, another callback, a network request, a locking + /// primitive, or any other number of ways that could impact other callbacks + /// from executing. + /// + /// Once the last callback associated with a channel is removed, [`Sender`]s + /// will begin returning disconnected errors. + pub fn on_receive_try(self, map: Map) -> Builder>::Next> + where + Map: FnMut(T) -> Result<(), CallbackDisconnected> + Send + 'static, + { + Builder { + mode: self + .mode + .push_callback(CallbackKind::Blocking(Box::new(map))), + bound: self.bound, + ty: self.ty, + } + } + + /// Invokes `on_receive` each time a value is sent to this channel. + /// + /// This function assumes `on_receive` will not block while waiting on + /// another thread, another process, another callback, a network request, a + /// locking primitive, or any other number of ways that could impact other + /// callbacks from executing in a shared environment. + pub fn on_receive_nonblocking( + self, + mut on_receive: Map, + ) -> Builder>::Next> + where + Map: FnMut(T) + Send + 'static, + { + self.on_receive_nonblocking_try(move |value| { + on_receive(value); + Ok(()) + }) + } + + /// Invokes `on_receive` each time a value is sent to this channel. Once an + /// error is returned, this callback will be removed from the channel. + /// + /// This function assumes `on_receive` will not block while waiting on + /// another thread, another process, another callback, a network request, a + /// locking primitive, or any other number of ways that could impact other + /// callbacks from executing in a shared environment. + /// + /// Once the last callback associated with a channel is removed, [`Sender`]s + /// will begin returning disconnected errors. + pub fn on_receive_nonblocking_try( + self, + mut map: Map, + ) -> Builder>::Next> + where + Map: FnMut(T) -> Result<(), CallbackDisconnected> + Send + 'static, + { + Builder { + mode: self + .mode + .push_callback(CallbackKind::NonBlocking(Box::new(move |value| { + map(value).map_err(|CallbackDisconnected| ChannelCallbackError::Disconnected) + }))), + bound: self.bound, + ty: self.ty, + } + } + + /// Invokes `on_receive` each time a value is sent to this channel. + pub fn on_receive_async( + self, + mut on_receive: Map, + ) -> Builder>::Next> + where + Map: FnMut(T) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + { + self.on_receive_async_try(move |value| { + let future = on_receive(value); + async move { + future.await; + Ok(()) + } + }) + } + + /// Invokes `on_receive` each time a value is sent to this channel. The + /// returned future will then be awaited. Once an error is returned, this + /// callback will be removed from the channel. + /// + /// Once the last callback associated with a channel is removed, [`Sender`]s + /// will begin returning disconnected errors. + pub fn on_receive_async_try( + self, + mut on_receive: Map, + ) -> Builder>::Next> + where + Map: FnMut(T) -> Fut + Send + 'static, + Fut: Future> + Send + 'static, + { + Builder { + mode: self + .mode + .push_callback(CallbackKind::NonBlocking(Box::new(move |value| { + let future = on_receive(value); + Err(ChannelCallbackError::Async(Box::pin(future))) + }))), + bound: self.bound, + ty: self.ty, + } + } + + /// Returns this builder reconfigured to create a [`BroadcastChannel`]. + /// + /// See the [`channel`](self) module documentation for an overview of the + /// channels provided. + pub fn broadcasting(self) -> Builder> { + Builder { + mode: self.mode.into(), + ty: self.ty, + bound: self.bound, + } + } + + /// Restricts this channel to `capacity` values queued. + pub fn bounded(mut self, capacity: usize) -> Self { + self.bound = Some(capacity); + self + } + + /// Returns the finished channel. + pub fn finish(self) -> Mode::Channel { + self.mode.finish(self.bound) + } +} + +/// Builder configuration for a single-consumer channel with no associated +/// callback. +pub struct SingleConsumer { + _private: (), +} + +impl ChannelMode for SingleConsumer +where + T: Send + 'static, +{ + type Channel = (Sender, Receiver); + type Next = SingleCallback; + + fn finish(self, limit: Option) -> Self::Channel { + let data = ChannelData::new(limit, super::SingleCallback::Receiver, 1, 1); + + (Sender { data: data.clone() }, Receiver { data }) + } +} + +impl sealed::ChannelMode for SingleConsumer { + type Next = SingleCallback; + + fn push_callback(self, cb: CallbackKind) -> Self::Next { + SingleCallback { cb } + } +} + +impl From for Broadcast { + fn from(_: SingleConsumer) -> Self { + Self { + callbacks: Vec::new(), + } + } +} + +/// Builder configuration for a single-consumer channel with an associated +/// callback. +pub struct SingleCallback { + cb: CallbackKind, +} + +impl ChannelMode for SingleCallback +where + T: Send + 'static, +{ + type Channel = Sender; + type Next = Broadcast; + + fn finish(self, limit: Option) -> Self::Channel { + let data = match self.cb { + CallbackKind::Blocking(cb) => { + let data = ChannelData::new(limit, super::SingleCallback::Receiver, 1, 1); + let receiver = Receiver { data: data.clone() }; + receiver.spawn_thread(cb); + data + } + CallbackKind::NonBlocking(cb) => { + ChannelData::new(limit, super::SingleCallback::Callback(cb), 1, 0) + } + }; + Sender { data } + } +} + +impl sealed::ChannelMode for SingleCallback { + type Next = Broadcast; + + fn push_callback(self, cb: CallbackKind) -> Self::Next { + Broadcast { + callbacks: vec![cb], + } + } +} + +impl From> for Broadcast { + fn from(single: SingleCallback) -> Self { + Self { + callbacks: vec![single.cb], + } + } +} + +/// Builder configuration for a [`BroadcastChannel`]. +pub struct Broadcast { + callbacks: Vec>, +} + +impl ChannelMode for Broadcast +where + T: Unpin + Clone + Send + 'static, +{ + type Channel = BroadcastChannel; + type Next = Self; + + fn finish(self, limit: Option) -> Self::Channel { + let callbacks = self + .callbacks + .into_iter() + .map(|cb| match cb { + CallbackKind::Blocking(cb) => BroadcastCallback::spawn_blocking(cb), + CallbackKind::NonBlocking(cb) => BroadcastCallback::NonBlocking(cb), + }) + .collect(); + let data = ChannelData::new(limit, MultipleCallbacks(callbacks), 1, 1); + BroadcastChannel { data } + } +} + +impl sealed::ChannelMode for Broadcast { + type Next = Self; + + fn push_callback(mut self, cb: CallbackKind) -> Self::Next { + self.callbacks.push(cb); + self + } +} + +/// A channel configuration. +pub trait ChannelMode: Into> { + /// The next configuration when a new callback is associated with this + /// builder. + type Next; + /// The resulting channel type created from this configuration. + type Channel; + + /// Returns the built channel. + fn finish(self, limit: Option) -> Self::Channel; +} + +mod sealed { + use crate::channel::sealed::CallbackKind; + + pub trait ChannelMode { + type Next; + + fn push_callback(self, callback: CallbackKind) -> Self::Next; + } +}