From 11c0c6e1d5ed3fb569a4c7c73b9be2ed3073c75c Mon Sep 17 00:00:00 2001 From: Lewin Bormann Date: Wed, 12 Jul 2023 22:42:52 +0200 Subject: [PATCH] Implement simple atomic stream select for #577 --- lib_eio/stream.ml | 46 +++++++++++++++++++++++++++++++++++++++++++++ lib_eio/stream.mli | 4 ++++ lib_eio/waiters.ml | 19 +++++++++++++++---- lib_eio/waiters.mli | 13 +++++++++++-- tests/stream.md | 19 +++++++++++++++++++ 5 files changed, 95 insertions(+), 6 deletions(-) diff --git a/lib_eio/stream.ml b/lib_eio/stream.ml index 974cfa3b7..8a11efc5f 100644 --- a/lib_eio/stream.ml +++ b/lib_eio/stream.ml @@ -94,6 +94,45 @@ module Locking = struct Mutex.unlock t.mutex; Some v + let select_of_many streams_fns = + let finished = Atomic.make false in + let cancel_fns = ref [] in + let add_cancel_fn fn = cancel_fns := fn :: !cancel_fns in + let cancel_all () = List.iter (fun fn -> fn ()) !cancel_fns in + let wait ctx enqueue (t, f) = begin + Mutex.lock t.mutex; + (* First check if any items are already available and return early if there are. *) + if not (Queue.is_empty t.items) + then ( + cancel_all (); + Mutex.unlock t.mutex; + enqueue (Ok (f (Queue.take t.items)))) + else add_cancel_fn @@ + (* Otherwise, register interest in this stream. *) + Waiters.cancellable_await_internal ~mutex:(Some t.mutex) t.readers t.id ctx (fun r -> + if Result.is_ok r then ( + if not (Atomic.compare_and_set finished false true) then ( + (* Another stream has yielded an item in the meantime. However, as + we have been waiting on this stream it must have been empty. + + As the stream's mutex was held since before last checking for an item, + the queue must be empty. + *) + assert ((Queue.length t.items) < t.capacity); + Queue.add (Result.get_ok r) t.items + ) else ( + (* remove all other entries of this fiber in other streams' waiters. *) + cancel_all () + )); + (* item is returned to waiting caller through enqueue and enter_unchecked. *) + enqueue (Result.map f r)) + end in + (* Register interest in all streams and return first available item. *) + let wait_for_stream streams_fns = begin + Suspend.enter_unchecked (fun ctx enqueue -> List.iter (wait ctx enqueue) streams_fns) + end in + wait_for_stream streams_fns + let length t = Mutex.lock t.mutex; let len = Queue.length t.items in @@ -125,6 +164,13 @@ let take_nonblocking = function | Sync x -> Sync.take_nonblocking x | Locking x -> Locking.take_nonblocking x +let select streams = + let filter s = match s with + | (Sync _, _) -> assert false + | (Locking x, f) -> (x, f) + in + Locking.select_of_many (List.map filter streams) + let length = function | Sync _ -> 0 | Locking x -> Locking.length x diff --git a/lib_eio/stream.mli b/lib_eio/stream.mli index 6554cac1a..79b7075b6 100644 --- a/lib_eio/stream.mli +++ b/lib_eio/stream.mli @@ -40,6 +40,10 @@ val take_nonblocking : 'a t -> 'a option Note that if another domain may add to the stream then a [None] result may already be out-of-date by the time this returns. *) +val select : ('a t * ('a -> 'b)) list -> 'b +(** [select] returns the first item yielded by any stream. This only + works for streams with non-zero capacity. *) + val length : 'a t -> int (** [length t] returns the number of items currently in [t]. *) diff --git a/lib_eio/waiters.ml b/lib_eio/waiters.ml index c0cbd4624..99c21155e 100644 --- a/lib_eio/waiters.ml +++ b/lib_eio/waiters.ml @@ -38,11 +38,12 @@ let rec wake_one t v = let is_empty = Lwt_dllist.is_empty -let await_internal ~mutex (t:'a t) id ctx enqueue = +let cancellable_await_internal ~mutex (t:'a t) id ctx enqueue = match Fiber_context.get_error ctx with | Some ex -> Option.iter Mutex.unlock mutex; - enqueue (Error ex) + enqueue (Error ex); + fun () -> () | None -> let resolved_waiter = ref Hook.null in let finished = Atomic.make false in @@ -56,14 +57,24 @@ let await_internal ~mutex (t:'a t) id ctx enqueue = enqueue (Error ex) ) in + let unwait () = + if Atomic.compare_and_set finished false true + then Hook.remove !resolved_waiter + in Fiber_context.set_cancel_fn ctx cancel; let waiter = { enqueue; finished } in match mutex with | None -> - resolved_waiter := add_waiter t waiter + resolved_waiter := add_waiter t waiter; + unwait | Some mutex -> resolved_waiter := add_waiter_protected ~mutex t waiter; - Mutex.unlock mutex + Mutex.unlock mutex; + unwait + +let await_internal ~mutex (t: 'a t) id ctx enqueue = + let _cancel = (cancellable_await_internal ~mutex t id ctx enqueue) in + () (* Returns a result if the wait succeeds, or raises if cancelled. *) let await ~mutex waiters id = diff --git a/lib_eio/waiters.mli b/lib_eio/waiters.mli index 724cf96e7..04b8d4557 100644 --- a/lib_eio/waiters.mli +++ b/lib_eio/waiters.mli @@ -27,8 +27,8 @@ val await : If [t] can be used from multiple domains: - [mutex] must be set to the mutex to use to unlock it. - [mutex] must be already held when calling this function, which will unlock it before blocking. - When [await] returns, [mutex] will have been unlocked. - @raise Cancel.Cancelled if the fiber's context is cancelled *) + When [await] returns, [mutex] will have been unlocked. + @raise Cancel.Cancelled if the fiber's context is cancelled *) val await_internal : mutex:Mutex.t option -> @@ -40,3 +40,12 @@ val await_internal : Note: [enqueue] is called from the triggering domain, which is currently calling {!wake_one} or {!wake_all} and must therefore be holding [mutex]. *) + +val cancellable_await_internal : + mutex:Mutex.t option -> + 'a t -> Ctf.id -> Fiber_context.t -> + (('a, exn) result -> unit) -> (unit -> unit) +(** Like [await_internal], but returns a function which, when called, + removes the current fiber continuation from the waiters list. + This is used when a fiber is waiting for multiple [Waiter]s simultaneously, + and needs to remove itself from other waiters once it has been enqueued by one.*) diff --git a/tests/stream.md b/tests/stream.md index c5a035e3b..10771d00d 100644 --- a/tests/stream.md +++ b/tests/stream.md @@ -357,3 +357,22 @@ Non-blocking take with zero-capacity stream: +Got None from stream - : unit = () ``` + +Selecting from multiple channels: + +```ocaml +# run @@ fun () -> Switch.run (fun sw -> + let t1, t2 = (S.create 2), (S.create 2) in + let selector = [(t1, fun x -> x); (t2, fun x -> x)] in + Fiber.fork ~sw (fun () -> S.add t2 "foo"); + Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector)); + Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector)); + Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector)); + Fiber.fork ~sw (fun () -> S.add t2 "bar"); + Fiber.fork ~sw (fun () -> S.add t1 "baz"); + ) ++foo ++bar ++baz +- : unit = () +```