diff --git a/lib_eio/eio.ml b/lib_eio/eio.ml
index f199e02d4..2445d7cad 100644
--- a/lib_eio/eio.ml
+++ b/lib_eio/eio.ml
@@ -8,6 +8,7 @@ module Semaphore = Semaphore
 module Mutex = Eio_mutex
 module Condition = Condition
 module Stream = Stream
+module Workpool = Workpool
 module Exn = Exn
 module Resource = Resource
 module Flow = Flow
diff --git a/lib_eio/eio.mli b/lib_eio/eio.mli
index 8a8a32d8e..797c6a2bb 100644
--- a/lib_eio/eio.mli
+++ b/lib_eio/eio.mli
@@ -39,6 +39,9 @@ module Stream = Stream
 (** Cancelling fibers. *)
 module Cancel = Eio__core.Cancel
 
+(** A high-level domain workpool *)
+module Workpool = Workpool
+
 (** Commonly used standard features. This module is intended to be [open]ed. *)
 module Std = Std
diff --git a/lib_eio/workpool.ml b/lib_eio/workpool.ml
new file mode 100644
index 000000000..cb05a8c0e
--- /dev/null
+++ b/lib_eio/workpool.ml
@@ -0,0 +1,146 @@
+type job = Pack : (unit -> 'a) * ('a, exn) Result.t Promise.u -> job
+
+type action =
+  | Process of job
+  | Quit of {
+      atomic: int Atomic.t;
+      target: int;
+      all_done: unit Promise.u;
+    }
+
+(* Worker: 1 domain/thread + m jobs per worker, n domains per workpool *)
+
+type t = {
+  (* The work queue *)
+  stream: action Stream.t;
+  (* Number of domains. Depending on settings, domains may run more than 1 job at a time. *)
+  domain_count: int;
+  (* True when [Workpool.terminate] has been called. *)
+  is_terminating: bool Atomic.t;
+  (* Resolved when the workpool begins terminating. *)
+  terminating: action Promise.t * action Promise.u;
+  (* Resolved when the workpool has terminated. *)
+  terminated: unit Promise.t * unit Promise.u;
+}
+
+let reject (Pack (_, w)) = Promise.resolve_error w (Failure "Workpool.terminate called")
+
+(* This function is the core of workpool.ml.
+   Each worker recursively calls [loop ()] until the [terminating]
+   promise is resolved. Workers pull one job at a time from the Stream. *)
+let start_worker ~sw ~limit ~terminating stream =
+  let capacity = ref limit in
+  let condition = Condition.create () in
+  let run_job job w =
+    Fiber.fork ~sw (fun () ->
+      decr capacity;
+      Promise.resolve w
+        (try Ok (job ()) with
+         | exn -> Error exn);
+      incr capacity;
+      Condition.broadcast condition )
+  in
+  (* The main worker loop. *)
+  let rec loop () =
+    let actions = Fiber.n_any [ (fun () -> Promise.await terminating); (fun () -> Stream.take stream) ] in
+    match actions with
+    | [ Process (Pack (job, w)) ] ->
+      (* We start the job right away. This also gives a chance to other domains
+         to start waiting on the Stream before the current thread blocks on [Stream.take] again. *)
+      run_job job w;
+      while !capacity = 0 do
+        Condition.await_no_mutex condition
+      done;
+      (loop [@tailcall]) ()
+    | Quit { atomic; target; all_done } :: maybe_job ->
+      List.iter
+        (function
+          | Process job -> reject job
+          | _ -> assert false)
+        maybe_job;
+      (* Wait until the completion of all of this worker's jobs. *)
+      while !capacity < limit do
+        Condition.await_no_mutex condition
+      done;
+      (* If we're the last worker terminating, resolve the promise. *)
+      if Atomic.fetch_and_add atomic 1 = target then Promise.resolve all_done ()
+    | _ -> assert false
+  in
+  loop ()
+
+(* Start a new domain. The worker will need a switch, then we start the worker. *)
+let start_domain ~sw ~domain_mgr ~limit ~terminating ~transient stream =
+  let go () =
+    Domain_manager.run domain_mgr (fun () ->
+      Switch.run @@ fun sw -> start_worker ~sw ~limit ~terminating stream )
+  in
+  (* [transient] workpools run as daemons to not hold the user's switch from completing.
+     It's up to the user to hold the switch open (and thus, the workpool)
+     by blocking on the jobs issued to the workpool.
+     [Workpool.run] and [Workpool.run_exn] will block so this shouldn't be a problem.
+     Still, the user can call [Workpool.create] with [~transient:false] to
+     disable this behavior, in which case the user must call [Workpool.terminate]
+     to release the switch. *)
+  match transient with
+  | false -> Fiber.fork ~sw go
+  | true ->
+    Fiber.fork_daemon ~sw (fun () ->
+      go ();
+      `Stop_daemon )
+
+let create ~sw ~domain_count ~domain_concurrency ?(capacity = 0) ?(transient = true) domain_mgr =
+  if capacity < 0 then raise (Invalid_argument "Workpool capacity < 0");
+  let stream = Stream.create capacity in
+  let instance =
+    {
+      stream;
+      domain_count;
+      is_terminating = Atomic.make false;
+      terminating = Promise.create ();
+      terminated = Promise.create ();
+    }
+  in
+  let terminating = fst instance.terminating in
+  for _ = 1 to domain_count do
+    start_domain ~sw ~domain_mgr ~limit:domain_concurrency ~terminating ~transient stream
+  done;
+  instance
+
+let run_promise ~sw { stream; _ } f =
+  let p, w = Promise.create () in
+  Fiber.fork_promise ~sw (fun () ->
+    Stream.add stream (Process (Pack (f, w)));
+    Promise.await_exn p )
+
+let run { stream; _ } f =
+  let p, w = Promise.create () in
+  Stream.add stream (Process (Pack (f, w)));
+  Promise.await p
+
+let run_exn instance f =
+  match run instance f with
+  | Ok x -> x
+  | Error exn -> raise exn
+
+let terminate ~sw ({ terminating = _, w1; terminated = p2, w2; _ } as instance) =
+  if Atomic.compare_and_set instance.is_terminating false true
+  then (
+    (* Instruct workers to shutdown *)
+    Promise.resolve w1 (Quit { atomic = Atomic.make 1; target = instance.domain_count; all_done = w2 });
+    (* Reject all present and future queued jobs *)
+    Fiber.fork_daemon ~sw (fun () ->
+      while true do
+        match Stream.take instance.stream with
+        | Process job -> reject job
+        | _ -> assert false
+      done;
+      `Stop_daemon );
+    (* Wait for all workers to have shutdown *)
+    Promise.await p2 )
+  else
+    (* [Workpool.terminate] was called more than once. *)
+    Promise.await p2
+
+let is_terminating { terminating = p, _; _ } = Promise.is_resolved p
+
+let is_terminated { terminated = p, _; _ } = Promise.is_resolved p
diff --git a/lib_eio/workpool.mli b/lib_eio/workpool.mli
new file mode 100644
index 000000000..382e36f4d
--- /dev/null
+++ b/lib_eio/workpool.mli
@@ -0,0 +1,40 @@
+type t
+
+(** Creates a new workpool with [domain_count] domains.
+
+    [domain_concurrency] is the maximum number of jobs that each domain can run at a time.
+
+    [capacity] (default: 0) is identical to the [Eio.Stream.create] capacity parameter.
+
+    [transient] (default: true). When true, the workpool will not block the [~sw] Switch from completing.
+    When false, you must call [terminate] to release the [~sw] Switch. *)
+val create :
+  sw:Switch.t ->
+  domain_count:int ->
+  domain_concurrency:int ->
+  ?capacity:int ->
+  ?transient:bool ->
+  #Domain_manager.t ->
+  t
+
+(** Run a job on this workpool. It is placed at the end of the queue. *)
+val run : t -> (unit -> 'a) -> ('a, exn) result
+
+(** Same as [run] but raises if the job failed. *)
+val run_exn : t -> (unit -> 'a) -> 'a
+
+(** Same as [run] but returns immediately, without blocking. *)
+val run_promise : sw:Switch.t -> t -> (unit -> 'a) -> ('a, exn) result Promise.t
+
+(** Waits for all running jobs to complete, then returns.
+    No new jobs are started, even if they were already enqueued.
+    To abort all running jobs instead of waiting for them, call [Switch.fail] on the Switch used to create this workpool *)
+val terminate : sw:Switch.t -> t -> unit
+
+(** Returns true if the [terminate] function has been called on this workpool.
+    Also returns true if the workpool has fully terminated. *)
+val is_terminating : t -> bool
+
+(** Returns true if the [terminate] function has been called on this workpool AND
+    the workpool has fully terminated (all running jobs have completed).
+*)
+val is_terminated : t -> bool
diff --git a/tests/workpool.md b/tests/workpool.md
new file mode 100644
index 000000000..ef1b2bbc5
--- /dev/null
+++ b/tests/workpool.md
@@ -0,0 +1,273 @@
+# Setting up the environment
+
+```ocaml
+# #require "eio_main";;
+```
+
+Creating some useful helper functions
+
+```ocaml
+open Eio.Std
+
+module Workpool = Eio.Workpool
+
+let () = Eio.Exn.Backend.show := false
+
+let sleep mono_clock ms = Eio.Time.Mono.sleep mono_clock ((Int.to_float ms) /. 1000.0)
+
+let epsilon = 20 (* milliseconds *)
+
+let around mono_clock max_ms f =
+  let t0 = Eio.Time.Mono.now mono_clock in
+  let res = f () in
+  let t1 = Eio.Time.Mono.now mono_clock in
+  let expected_top = Mtime.Span.((max_ms + epsilon) * ms) in
+  let expected_bottom = Mtime.Span.((max 0 (max_ms - epsilon)) * ms) in
+  let actual = Mtime.span t0 t1 in
+  if Mtime.Span.compare actual expected_top <= 0 && Mtime.Span.compare actual expected_bottom >= 0
+  then res
+  else failwith (Format.asprintf "Duration not %a >= %a =< %a"
+    Mtime.Span.pp expected_bottom Mtime.Span.pp actual Mtime.Span.pp expected_top
+  )
+
+let run fn =
+  Eio_main.run @@ fun env ->
+  let mono_clock = (Eio.Stdenv.mono_clock env) in
+  fn (Eio.Stdenv.domain_mgr env) (sleep mono_clock) (around mono_clock)
+```
+
+# Workpool.create
+
+Workpool is created, transient by default:
+
+```ocaml
+# run @@ fun mgr sleep around ->
+  Switch.run @@ fun sw ->
+  ignore @@ Workpool.create
+    ~sw ~domain_count:2 ~domain_concurrency:1 mgr
+  ;;
+- : unit = ()
+```
+
+Workpool holds up the switch when non-transient:
+
+```ocaml
+# run @@ fun mgr sleep around ->
+  let terminated = ref false in
+  Switch.run (fun sw ->
+    let wp =
+      Workpool.create
+        ~sw ~domain_count:2 ~domain_concurrency:1 mgr ~transient:false
+    in
+    Fiber.fork_daemon ~sw (fun () ->
+      Fiber.yield ();
+      terminated := true;
+      Workpool.terminate ~sw wp;
+      `Stop_daemon
+    )
+  );
+  !terminated
+  ;;
+- : bool = true
+```
+
+# Concurrency
+
+Runs jobs in parallel as much as possible (domains):
+
+```ocaml
+# run @@ fun mgr sleep around ->
+  Switch.run @@ fun sw ->
+  let total = Atomic.make 0 in
+  let wp = Workpool.create ~sw ~domain_count:2 ~domain_concurrency:1 mgr in
+  around 300 (fun () ->
+    List.init 5 (fun i -> i + 1)
+    |> Fiber.List.iter (fun i -> Workpool.run_exn wp (fun () ->
+      sleep 100;
+      ignore @@ Atomic.fetch_and_add total i
+    ));
+    Atomic.get total
+  );;
+- : int = 15
+```
+
+Runs jobs in parallel as much as possible (workers):
+
+```ocaml
+# run @@ fun mgr sleep around ->
+  Switch.run @@ fun sw ->
+  let total = Atomic.make 0 in
+  let wp = Workpool.create ~sw ~domain_count:1 ~domain_concurrency:2 mgr in
+  around 300 (fun () ->
+    List.init 5 (fun i -> i + 1)
+    |> Fiber.List.iter (fun i -> Workpool.run_exn wp (fun () ->
+      sleep 100;
+      ignore @@ Atomic.fetch_and_add total i
+    ));
+    Atomic.get total
+  );;
+- : int = 15
+```
+
+Runs jobs in parallel as much as possible (both):
+
+```ocaml
+# run @@ fun mgr sleep around ->
+  Switch.run @@ fun sw ->
+  let total = Atomic.make 0 in
+  let wp = Workpool.create ~sw ~domain_count:2 ~domain_concurrency:2 mgr in
+  around 200 (fun () ->
+    List.init 5 (fun i -> i + 1)
+    |> Fiber.List.iter (fun i -> Workpool.run_exn wp (fun () ->
+      sleep 100;
+      ignore @@ Atomic.fetch_and_add total i
+    ));
+    Atomic.get total
+  );;
+- : int = 15
+```
+
+# Job error handling
+
+`Workpool.run` returns a Result:
+
+```ocaml
+# run @@ fun mgr sleep around ->
+  Switch.run @@ fun sw ->
+  let total = Atomic.make 0 in
+  let wp = Workpool.create ~sw ~domain_count:1 ~domain_concurrency:4 mgr in
+  around 200 (fun () ->
+    let results =
+      List.init 5 (fun i -> i + 1)
+      |> Fiber.List.map (fun i -> Workpool.run wp (fun () ->
+        sleep 100;
+        if i mod 2 = 0
+        then failwith (Int.to_string i)
+        else (Atomic.fetch_and_add total i)
+      ))
+    in
+    results, Atomic.get total
+  );;
+- : (int, exn) result list * int =
+([Ok 0; Error (Failure "2"); Ok 1; Error (Failure "4"); Ok 4], 9)
+```
+
+`Workpool.run_exn` raises:
+
+```ocaml
+# run @@ fun
+  mgr sleep around ->
+  Switch.run @@ fun sw ->
+  let total = Atomic.make 0 in
+  let wp = Workpool.create ~sw ~domain_count:1 ~domain_concurrency:2 mgr in
+  around 200 (fun () ->
+    List.init 5 (fun i -> i + 1)
+    |> Fiber.List.map (fun i -> Workpool.run_exn wp (fun () ->
+      traceln "Started %d" i;
+      let x = Atomic.fetch_and_add total i in
+      if x = 3
+      then failwith (Int.to_string i)
+      else x
+    ))
+  );;
++Started 1
++Started 2
++Started 3
++Started 4
+Exception: Failure "3".
+```
+
+# Blocking for capacity
+
+`Workpool.run` will block waiting for room in the queue:
+
+```ocaml
+# run @@ fun mgr sleep around ->
+  Switch.run @@ fun sw ->
+  let wp = Workpool.create ~sw ~domain_count:1 ~domain_concurrency:1 mgr in
+
+  let p1 = Fiber.fork_promise ~sw (fun () -> Workpool.run_exn wp (fun () -> sleep 100)) in
+
+  around 100 (fun () -> Workpool.run_exn wp @@ fun () -> ());
+
+  around 0 (fun () -> Promise.await_exn p1)
+  ;;
+- : unit = ()
+```
+
+`Workpool.run_promise` will not block if there's not enough room in the queue:
+
+```ocaml
+# run @@ fun mgr sleep around ->
+  Switch.run @@ fun sw ->
+  let wp = Workpool.create ~sw ~domain_count:1 ~domain_concurrency:1 mgr in
+
+  let p1 = around 0 (fun () ->
+    Fiber.fork_promise ~sw (fun () -> Workpool.run_exn wp (fun () -> sleep 100))
+  )
+  in
+  let p2 = around 0 (fun () ->
+    Fiber.fork_promise ~sw (fun () -> Workpool.run_exn wp (fun () -> sleep 100))
+  )
+  in
+  let p3 = around 0 (fun () ->
+    Workpool.run_promise ~sw wp (fun () -> ())
+  )
+  in
+
+  around 200 (fun () ->
+    Promise.await_exn p1;
+    Promise.await_exn p2;
+    Promise.await_exn p3;
+    (* Value restriction :( *)
+    Promise.create_resolved (Ok ())
+  )
+  |> Promise.await_exn
+  ;;
+- : unit = ()
+```
+
+# Termination
+
+`Workpool.terminate` waits for jobs currently running to finish and rejects queued jobs:
+
+```ocaml
+# run @@ fun mgr sleep around ->
+  let print_status wp =
+    traceln "Terminating: %b (terminated: %b)"
+      (Workpool.is_terminating wp) (Workpool.is_terminated wp)
+  in
+  Switch.run @@ fun sw ->
+  let total = Atomic.make 0 in
+  let wp = Workpool.create ~sw ~domain_count:2 ~domain_concurrency:2 mgr in
+  let results = Fiber.fork_promise ~sw (fun () ->
+    around 300 (fun () ->
+      List.init 5 (fun i -> i + 1)
+      |> Fiber.List.iter (fun i -> Workpool.run_exn wp (fun () ->
+        sleep 150;
+        ignore @@ Atomic.fetch_and_add total i
+      ));
+      Atomic.get total
+    )
+  )
+  in
+  sleep 75;
+  (* Exactly one job should be left in the queue
+     for Workpool.terminate to reject *)
+  let x = around 75 (fun () ->
+    print_status wp;
+    let p = Fiber.fork_promise ~sw (fun () -> Workpool.terminate ~sw wp) in
+    print_status wp;
+    Promise.await_exn p;
+    print_status wp;
+    Atomic.get total
+  )
+  in
+  traceln "Total: %d (terminated: %b)" x (Workpool.is_terminated wp);
+  Promise.await_exn results
+  ;;
++Terminating: false (terminated: false)
++Terminating: true (terminated: false)
++Terminating: true (terminated: true)
++Total: 10 (terminated: true)
+Exception: Failure "Workpool.terminate called".
+```