Skip to content

Commit

Permalink
Implement simple atomic stream select
Browse files Browse the repository at this point in the history
  • Loading branch information
dermesser committed Jul 12, 2023
1 parent 57ace76 commit 11c0c6e
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 6 deletions.
46 changes: 46 additions & 0 deletions lib_eio/stream.ml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,45 @@ module Locking = struct
Mutex.unlock t.mutex;
Some v

let select_of_many streams_fns =
let finished = Atomic.make false in
let cancel_fns = ref [] in
let add_cancel_fn fn = cancel_fns := fn :: !cancel_fns in
let cancel_all () = List.iter (fun fn -> fn ()) !cancel_fns in
let wait ctx enqueue (t, f) = begin
Mutex.lock t.mutex;
(* First check if any items are already available and return early if there are. *)
if not (Queue.is_empty t.items)
then (
cancel_all ();
Mutex.unlock t.mutex;
enqueue (Ok (f (Queue.take t.items))))
else add_cancel_fn @@
(* Otherwise, register interest in this stream. *)
Waiters.cancellable_await_internal ~mutex:(Some t.mutex) t.readers t.id ctx (fun r ->
if Result.is_ok r then (
if not (Atomic.compare_and_set finished false true) then (
(* Another stream has yielded an item in the meantime. However, as
we have been waiting on this stream it must have been empty.
As the stream's mutex was held since before last checking for an item,
the queue must be empty.
*)
assert ((Queue.length t.items) < t.capacity);
Queue.add (Result.get_ok r) t.items
) else (
(* remove all other entries of this fiber in other streams' waiters. *)
cancel_all ()
));
(* item is returned to waiting caller through enqueue and enter_unchecked. *)
enqueue (Result.map f r))
end in
(* Register interest in all streams and return first available item. *)
let wait_for_stream streams_fns = begin
Suspend.enter_unchecked (fun ctx enqueue -> List.iter (wait ctx enqueue) streams_fns)
end in
wait_for_stream streams_fns

let length t =
Mutex.lock t.mutex;
let len = Queue.length t.items in
Expand Down Expand Up @@ -125,6 +164,13 @@ let take_nonblocking = function
| Sync x -> Sync.take_nonblocking x
| Locking x -> Locking.take_nonblocking x

let select streams =
let filter s = match s with
| (Sync _, _) -> assert false
| (Locking x, f) -> (x, f)
in
Locking.select_of_many (List.map filter streams)

let length = function
| Sync _ -> 0
| Locking x -> Locking.length x
Expand Down
4 changes: 4 additions & 0 deletions lib_eio/stream.mli
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ val take_nonblocking : 'a t -> 'a option
Note that if another domain may add to the stream then a [None]
result may already be out-of-date by the time this returns. *)

val select : ('a t * ('a -> 'b)) list -> 'b
(** [select] returns the first item yielded by any stream. This only
works for streams with non-zero capacity. *)

val length : 'a t -> int
(** [length t] returns the number of items currently in [t]. *)

Expand Down
19 changes: 15 additions & 4 deletions lib_eio/waiters.ml
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ let rec wake_one t v =

let is_empty = Lwt_dllist.is_empty

let await_internal ~mutex (t:'a t) id ctx enqueue =
let cancellable_await_internal ~mutex (t:'a t) id ctx enqueue =
match Fiber_context.get_error ctx with
| Some ex ->
Option.iter Mutex.unlock mutex;
enqueue (Error ex)
enqueue (Error ex);
fun () -> ()
| None ->
let resolved_waiter = ref Hook.null in
let finished = Atomic.make false in
Expand All @@ -56,14 +57,24 @@ let await_internal ~mutex (t:'a t) id ctx enqueue =
enqueue (Error ex)
)
in
let unwait () =
if Atomic.compare_and_set finished false true
then Hook.remove !resolved_waiter
in
Fiber_context.set_cancel_fn ctx cancel;
let waiter = { enqueue; finished } in
match mutex with
| None ->
resolved_waiter := add_waiter t waiter
resolved_waiter := add_waiter t waiter;
unwait
| Some mutex ->
resolved_waiter := add_waiter_protected ~mutex t waiter;
Mutex.unlock mutex
Mutex.unlock mutex;
unwait

let await_internal ~mutex (t: 'a t) id ctx enqueue =
let _cancel = (cancellable_await_internal ~mutex t id ctx enqueue) in
()

(* Returns a result if the wait succeeds, or raises if cancelled. *)
let await ~mutex waiters id =
Expand Down
13 changes: 11 additions & 2 deletions lib_eio/waiters.mli
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ val await :
If [t] can be used from multiple domains:
- [mutex] must be set to the mutex to use to unlock it.
- [mutex] must be already held when calling this function, which will unlock it before blocking.
When [await] returns, [mutex] will have been unlocked.
@raise Cancel.Cancelled if the fiber's context is cancelled *)
When [await] returns, [mutex] will have been unlocked.
@raise Cancel.Cancelled if the fiber's context is cancelled *)

val await_internal :
mutex:Mutex.t option ->
Expand All @@ -40,3 +40,12 @@ val await_internal :
Note: [enqueue] is called from the triggering domain,
which is currently calling {!wake_one} or {!wake_all}
and must therefore be holding [mutex]. *)

val cancellable_await_internal :
mutex:Mutex.t option ->
'a t -> Ctf.id -> Fiber_context.t ->
(('a, exn) result -> unit) -> (unit -> unit)
(** Like [await_internal], but returns a function which, when called,
removes the current fiber continuation from the waiters list.
This is used when a fiber is waiting for multiple [Waiter]s simultaneously,
and needs to remove itself from other waiters once it has been enqueued by one.*)
19 changes: 19 additions & 0 deletions tests/stream.md
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,22 @@ Non-blocking take with zero-capacity stream:
+Got None from stream
- : unit = ()
```

Selecting from multiple channels:

```ocaml
# run @@ fun () -> Switch.run (fun sw ->
let t1, t2 = (S.create 2), (S.create 2) in
let selector = [(t1, fun x -> x); (t2, fun x -> x)] in
Fiber.fork ~sw (fun () -> S.add t2 "foo");
Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector));
Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector));
Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector));
Fiber.fork ~sw (fun () -> S.add t2 "bar");
Fiber.fork ~sw (fun () -> S.add t1 "baz");
)
+foo
+bar
+baz
- : unit = ()
```

0 comments on commit 11c0c6e

Please sign in to comment.