Skip to content

Commit

Permalink
Safe Fiber races: ~combine and n_any
Browse files Browse the repository at this point in the history
Co-authored-by: Thomas Leonard <[email protected]>
  • Loading branch information
SGrondin and talex5 committed Jan 3, 2024
1 parent c9db164 commit cbb6ece
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 21 deletions.
19 changes: 14 additions & 5 deletions lib_eio/core/eio__core.mli
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ module Fiber : sig
(** [all fs] is like [both], but for any number of fibers.
[all []] returns immediately. *)

val first : (unit -> 'a) -> (unit -> 'a) -> 'a
val first : ?combine:('a -> 'a -> 'a) -> (unit -> 'a) -> (unit -> 'a) -> 'a
(** [first f g] runs [f ()] and [g ()] concurrently.
They run in a new cancellation sub-context, and when one finishes the other is cancelled.
Expand All @@ -216,15 +216,24 @@ module Fiber : sig
If both fibers fail, {!Exn.combine} is used to combine the exceptions.
Warning: it is always possible that {i both} operations will succeed (and one result will be thrown away).
This is because there is a period of time after the first operation succeeds,
but before its fiber finishes, during which the other operation may also succeed. *)
Warning: it is always possible that {i both} operations will succeed.
This is because there is a period of time after the first operation succeeds
when it is waiting in the run-queue to resume
during which the other operation may also succeed.
val any : (unit -> 'a) list -> 'a
If both fibers succeed, [combine a b] is used to combine the results
(where [a] is the result of the first fiber to return and [b] is the second result).
The default is [fun a _ -> a], which discards the later result. *)

val any : ?combine:('a -> 'a -> 'a) -> (unit -> 'a) list -> 'a
(** [any fs] is like [first], but for any number of fibers.
[any []] just waits forever (or until cancelled). *)

val n_any : (unit -> 'a) list -> 'a list
(** [n_any fs] is like [any], expect that if multiple fibers return values
then they are all returned, in the order in which the fibers finished. *)

val await_cancel : unit -> 'a
(** [await_cancel ()] waits until cancelled.
@raise Cancel.Cancelled *)
Expand Down
41 changes: 26 additions & 15 deletions lib_eio/core/fiber.ml
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,22 @@ let await_cancel () =
Suspend.enter "await_cancel" @@ fun fiber enqueue ->
Cancel.Fiber_context.set_cancel_fn fiber (fun ex -> enqueue (Error ex))

let any fs =
let r = ref `None in
type 'a any_status =
| New
| Ex of (exn * Printexc.raw_backtrace)
| OK of 'a

let any_gen ~return ~combine fs =
let r = ref New in
let parent_c =
Cancel.sub_unchecked Any (fun cc ->
let wrap h =
match h () with
| x ->
begin match !r with
| `None -> r := `Ok x; Cancel.cancel cc Not_first
| `Ex _ | `Ok _ -> ()
| New -> r := OK (return x); Cancel.cancel cc Not_first
| OK prev -> r := OK (combine prev x)
| Ex _ -> ()
end
| exception Cancel.Cancelled _ when not (Cancel.is_on cc) ->
(* If this is in response to us asking the fiber to cancel then we can just ignore it.
Expand All @@ -105,11 +111,11 @@ let any fs =
()
| exception ex ->
begin match !r with
| `None -> r := `Ex (ex, Printexc.get_raw_backtrace ()); Cancel.cancel cc ex
| `Ok _ -> r := `Ex (ex, Printexc.get_raw_backtrace ())
| `Ex prev ->
| New -> r := Ex (ex, Printexc.get_raw_backtrace ()); Cancel.cancel cc ex
| OK _ -> r := Ex (ex, Printexc.get_raw_backtrace ())
| Ex prev ->
let bt = Printexc.get_raw_backtrace () in
r := `Ex (Exn.combine prev (ex, bt))
r := Ex (Exn.combine prev (ex, bt))
end
in
let vars = Cancel.Fiber_context.get_vars () in
Expand All @@ -121,7 +127,7 @@ let any fs =
let p, r = Promise.create_with_id (Cancel.Fiber_context.tid new_fiber) in
fork_raw new_fiber (fun () ->
match wrap f with
| x -> Promise.resolve_ok r x
| () -> Promise.resolve_ok r ()
| exception ex -> Promise.resolve_error r ex
);
p :: aux fs
Expand All @@ -131,16 +137,21 @@ let any fs =
)
in
match !r, Cancel.get_error parent_c with
| `Ok r, None -> r
| (`Ok _ | `None), Some ex -> raise ex
| `Ex (ex, bt), None -> Printexc.raise_with_backtrace ex bt
| `Ex ex1, Some ex2 ->
| OK r, None -> r
| (OK _ | New), Some ex -> raise ex
| Ex (ex, bt), None -> Printexc.raise_with_backtrace ex bt
| Ex ex1, Some ex2 ->
let bt2 = Printexc.get_raw_backtrace () in
let ex, bt = Exn.combine ex1 (ex2, bt2) in
Printexc.raise_with_backtrace ex bt
| `None, None -> assert false
| New, None -> assert false

let n_any fs =
List.rev (any_gen fs ~return:(fun x -> [x]) ~combine:(fun xs x -> x :: xs))

let any ?(combine=(fun x _ -> x)) fs = any_gen fs ~return:Fun.id ~combine

let first f g = any [f; g]
let first ?combine f g = any ?combine [f; g]

let is_cancelled () =
let ctx = Effect.perform Cancel.Get_context in
Expand Down
117 changes: 116 additions & 1 deletion tests/fiber.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Second finishes, first is cancelled:
- : unit = ()
```

If both succeed, we pick the first one:
If both succeed and no ~combine, we pick the first one by default:

```ocaml
# run @@ fun () ->
Expand All @@ -49,6 +49,73 @@ If both succeed, we pick the first one:
- : unit = ()
```

If both succeed we let ~combine decide:

```ocaml
# run @@ fun () ->
Fiber.first ~combine:(fun _ x -> x)
(fun () -> "a")
(fun () -> "b");;
+b
- : unit = ()
```

It allows for safe Stream.take races (both):

```ocaml
# run @@ fun () ->
let stream = Eio.Stream.create 1 in
Fiber.first ~combine:(fun x y -> x ^ y)
(fun () ->
Fiber.yield ();
Eio.Stream.add stream "b";
"a"
)
(fun () -> Eio.Stream.take stream);;
+ab
- : unit = ()
```

It allows for safe Stream.take races (f is first):

```ocaml
# run @@ fun () ->
let stream = Eio.Stream.create 1 in
let out =
Fiber.first ~combine:(fun x y -> x ^ y)
(fun () ->
Eio.Stream.add stream "b";
Fiber.yield ();
"a"
)
(fun () ->
Fiber.yield ();
Eio.Stream.take stream)
in
out ^ Int.to_string (Eio.Stream.length stream);;
+a1
- : unit = ()
```

It allows for safe Stream.take races (g is first):

```ocaml
# run @@ fun () ->
let stream = Eio.Stream.create 1 in
let out =
Fiber.first ~combine:(fun x y -> x ^ y)
(fun () ->
Eio.Stream.add stream "b";
Fiber.yield ();
"a"
)
(fun () -> Eio.Stream.take stream)
in
out ^ Int.to_string (Eio.Stream.length stream);;
+b0
- : unit = ()
```

One crashes - report it:

```ocaml
Expand Down Expand Up @@ -201,6 +268,54 @@ Exception: Stdlib.Exit.
- : unit = ()
```

`Fiber.any` with combine collects all results:

```ocaml
# run @@ fun () ->
Fiber.any
~combine:(fun x y -> x @ y)
(List.init 3 (fun x () -> traceln "%d" x; [x]))
|> Fmt.(str "%a" (Dump.list int));;
+0
+1
+2
+[0; 1; 2]
- : unit = ()
```

# Fiber.n_any

`Fiber.n_any` behaves just like `Fiber.any` when there's only one result:

```ocaml
# run @@ fun () ->
Fiber.n_any (List.init 3 (fun x () -> traceln "%d" x; Fiber.yield (); x))
|> Fmt.(str "%a" (Dump.list int));;
+0
+1
+2
+[0]
- : unit = ()
```

`Fiber.n_any` collects all results:

```ocaml
# run @@ fun () ->
(Fiber.n_any (List.init 4 (fun x () ->
traceln "%d" x;
if x = 1 then Fiber.yield ();
x
)))
|> Fmt.(str "%a" (Dump.list int));;
+0
+1
+2
+3
+[0; 2; 3]
- : unit = ()
```

# Fiber.await_cancel

```ocaml
Expand Down

0 comments on commit cbb6ece

Please sign in to comment.