diff --git a/bench/bench_run.ocaml4.ml b/bench/bench_run.ocaml4.ml new file mode 100644 index 00000000..ce98597c --- /dev/null +++ b/bench/bench_run.ocaml4.ml @@ -0,0 +1 @@ +let run_suite ~budgetf:_ = [] diff --git a/bench/bench_run.ocaml5.ml b/bench/bench_run.ocaml5.ml new file mode 100644 index 00000000..a6909902 --- /dev/null +++ b/bench/bench_run.ocaml5.ml @@ -0,0 +1,36 @@ +open Multicore_bench +open Picos_std_structured +module Multififos = Picos_mux_multififo + +let run_one_multififo ~budgetf ~n_domains ~n () = + let context = ref (Obj.magic ()) in + + let m = if n < 1_000_000 then 1_000_000 / n else 1 in + + let before _ = context := Multififos.context () in + let init _ = !context in + let wrap i context action = + if i = 0 then Multififos.run ~context action else action () + in + let work i context = + if i <> 0 then Multififos.runner_on_this_thread context + else + for _ = 1 to m do + Run.for_n n ignore + done + in + + let config = + Printf.sprintf "%d mfifo%s, run_n %d" n_domains + (if n_domains = 1 then "" else "s") + n + in + Times.record ~budgetf ~n_domains ~before ~init ~wrap ~work () + |> Times.to_thruput_metrics ~n:(n * m) ~singular:"ignore" ~config + +let run_suite ~budgetf = + Util.cross [ 1; 2; 4; 8 ] + [ 100; 1_000; 10_000; 100_000; 1_000_000; 10_000_000 ] + |> List.concat_map @@ fun (n_domains, n) -> + if Picos_domain.recommended_domain_count () < n_domains then [] + else run_one_multififo ~budgetf ~n_domains ~n () diff --git a/bench/dune b/bench/dune index a32c509a..aaadd683 100644 --- a/bench/dune +++ b/bench/dune @@ -23,6 +23,7 @@ (run %{test} -brief "Picos binaries") (run %{test} -brief "Bounded_q with Picos_std_sync") (run %{test} -brief "Memory usage") + (run %{test} -brief "Picos_std_structured.Run") ;; )) (foreign_stubs @@ -49,6 +50,11 @@ from (picos_mux.fifo -> scheduler.ocaml5.ml) (picos_mux.thread -> scheduler.ocaml4.ml)) + (select + bench_run.ml + from + (picos_mux.multififo -> bench_run.ocaml5.ml) + (-> bench_run.ocaml4.ml)) (select bench_fib.ml from diff --git a/bench/main.ml b/bench/main.ml index e26b3f38..0d0f7672 100644 --- a/bench/main.ml +++ b/bench/main.ml @@ -22,6 +22,7 @@ let benchmarks = ("Picos binaries", Bench_binaries.run_suite); ("Bounded_q with Picos_std_sync", Bench_bounded_q.run_suite); ("Memory usage", Bench_memory.run_suite); + ("Picos_std_structured.Run", Bench_run.run_suite); ] let () = Multicore_bench.Cmd.run ~benchmarks () diff --git a/lib/picos_std.structured/bundle.ml b/lib/picos_std.structured/bundle.ml index d5ae07c2..5f793c2a 100644 --- a/lib/picos_std.structured/bundle.ml +++ b/lib/picos_std.structured/bundle.ml @@ -16,18 +16,17 @@ type t = [ `Bundle ] tdt external config_as_atomic : t -> int Atomic.t = "%identity" -let config_terminated_bit = 0x01 -and config_callstack_mask = 0x3E -and config_callstack_shift = 1 -and config_one = 0x40 (* memory runs out before overflow *) +let config_on_return_terminate_bit = 0x01 +and config_on_terminate_raise_bit = 0x02 +and config_callstack_mask = 0x6C +and config_callstack_shift = 2 +and config_one = 0x80 (* memory runs out before overflow *) let flock_key : [ `Bundle | `Nothing ] tdt Fiber.FLS.t = Fiber.FLS.create () -let terminate_as callstack (Bundle { bundle = Packed bundle; _ } : t) = - Computation.cancel bundle Control.Terminate callstack - -let terminate ?callstack t = - terminate_as (Control.get_callstack_opt callstack) t +let terminate ?callstack (Bundle { bundle = Packed bundle; _ } : t) = + Computation.cancel bundle Control.Terminate + (Control.get_callstack_opt callstack) let terminate_after ?callstack (Bundle { bundle = Packed bundle; _ } : t) ~seconds = @@ -39,11 +38,13 @@ let error ?callstack (Bundle r as t : t) exn bt = terminate ?callstack t; Control.Errors.push r.errors exn bt end + else if + Atomic.get (config_as_atomic t) land config_on_terminate_raise_bit <> 0 + then terminate ?callstack t let decr (Bundle r as t : t) = let n = Atomic.fetch_and_add (config_as_atomic t) (-config_one) in if n < config_one * 2 then begin - terminate_as Control.empty_bt t; Trigger.signal r.finished end @@ -51,6 +52,10 @@ type _ pass = FLS : unit pass | Arg : t pass let[@inline never] no_flock () = invalid_arg "no flock" +let[@inline] on_terminate = function + | None | Some `Ignore -> `Ignore + | Some `Raise -> `Raise + let get_flock fiber = match Fiber.FLS.get fiber flock_key ~default:Nothing with | Bundle _ as t -> t @@ -58,6 +63,8 @@ let get_flock fiber = let await (Bundle r as t : t) fiber packed canceler outer = Fiber.set_computation fiber packed; + if Fiber.FLS.get fiber flock_key ~default:Nothing != outer then + Fiber.FLS.set fiber flock_key outer; let forbid = Fiber.exchange fiber ~forbid:true in let n = Atomic.fetch_and_add (config_as_atomic t) (-config_one) in if config_one * 2 <= n then begin @@ -66,14 +73,22 @@ let await (Bundle r as t : t) fiber packed canceler outer = write from being delayed after the [Trigger.await] below. *) if config_one <= Atomic.fetch_and_add (config_as_atomic t) 0 then Trigger.await r.finished |> ignore - end - else terminate_as Control.empty_bt t; + end; Fiber.set fiber ~forbid; - if Fiber.FLS.get fiber flock_key ~default:Nothing != outer then - Fiber.FLS.set fiber flock_key outer; let (Packed parent) = packed in Computation.detach parent canceler; Control.Errors.check r.errors; + begin + let (Packed bundle) = r.bundle in + match Computation.peek_exn bundle with + | _ -> () + | exception Computation.Running -> + Computation.cancel bundle Control.Terminate Control.empty_bt + | exception Control.Terminate + when Atomic.get (config_as_atomic t) land config_on_terminate_raise_bit + = 0 -> + () + end; Fiber.check fiber let[@inline never] raised exn t fiber packed canceler outer = @@ -84,7 +99,7 @@ let[@inline never] raised exn t fiber packed canceler outer = let[@inline never] returned value (t : t) fiber packed canceler outer = let config = Atomic.get (config_as_atomic t) in - if config land config_terminated_bit <> 0 then begin + if config land config_on_return_terminate_bit <> 0 then begin let callstack = let n = (config land config_callstack_mask) lsr config_callstack_shift in if n = 0 then None else Some n @@ -99,25 +114,30 @@ let join_after_realloc x fn t fiber packed canceler outer = | value -> returned value t fiber packed canceler outer | exception exn -> raised exn t fiber packed canceler outer -let join_after_pass (type a) ?callstack ?on_return (fn : a -> _) (pass : a pass) - = +let join_after_pass (type a) ?callstack ?on_return ?on_terminate (fn : a -> _) + (pass : a pass) = (* The sequence of operations below ensures that nothing is leaked. *) let (Bundle r as t : t) = - let terminated = + let config = match on_return with - | None | Some `Wait -> 0 - | Some `Terminate -> config_terminated_bit + | None | Some `Wait -> config_one + | Some `Terminate -> config_one lor config_on_return_terminate_bit in - let callstack = + let config = + match on_terminate with + | None | Some `Ignore -> config + | Some `Raise -> config lor config_on_terminate_raise_bit + in + let config = match callstack with - | None -> 0 + | None -> config | Some n -> - if n <= 0 then 0 + if n <= 0 then config else - Int.min n (config_callstack_mask lsr config_callstack_shift) - lsl config_callstack_shift + config + lor Int.min n (config_callstack_mask lsr config_callstack_shift) + lsl config_callstack_shift in - let config = config_one lor callstack lor terminated in let bundle = Computation.Packed (Computation.create ~mode:`LIFO ()) in let errors = Control.Errors.create () in let finished = Trigger.signaled in @@ -219,8 +239,8 @@ let fork_pass (type a) (Bundle r as t : t) thunk (pass : a pass) = let is_running (Bundle { bundle = Packed bundle; _ } : t) = Computation.is_running bundle -let join_after ?callstack ?on_return fn = - join_after_pass ?callstack ?on_return fn Arg +let join_after ?callstack ?on_return ?on_terminate fn = + join_after_pass ?callstack ?on_return ?on_terminate fn Arg let fork t thunk = fork_pass t thunk Arg let fork_as_promise t thunk = fork_as_promise_pass t thunk Arg diff --git a/lib/picos_std.structured/control.ml b/lib/picos_std.structured/control.ml index d2808f22..3bd517af 100644 --- a/lib/picos_std.structured/control.ml +++ b/lib/picos_std.structured/control.ml @@ -41,13 +41,15 @@ module Errors = struct | [ (exn, bt) ] -> Printexc.raise_with_backtrace exn bt | exn_bts -> check exn_bts [] - let rec push t exn bt backoff = - let before = Atomic.get t in - let after = (exn, bt) :: before in - if not (Atomic.compare_and_set t before after) then - push t exn bt (Backoff.once backoff) - - let push t exn bt = push t exn bt Backoff.default + let push t exn bt = + let backoff = ref Backoff.default in + while + let before = Atomic.get t in + let after = (exn, bt) :: before in + not (Atomic.compare_and_set t before after) + do + backoff := Backoff.once !backoff + done end let raise_if_canceled () = Fiber.check (Fiber.current ()) diff --git a/lib/picos_std.structured/dune b/lib/picos_std.structured/dune index fc6e5764..df0659f5 100644 --- a/lib/picos_std.structured/dune +++ b/lib/picos_std.structured/dune @@ -1,3 +1,15 @@ +(rule + (enabled_if + (<= 5.0.0 %{ocaml_version})) + (action + (copy for.ocaml5.ml for.ml))) + +(rule + (enabled_if + (< %{ocaml_version} 5.0.0)) + (action + (copy for.ocaml4.ml for.ml))) + (library (name picos_std_structured) (public_name picos_std.structured) diff --git a/lib/picos_std.structured/flock.ml b/lib/picos_std.structured/flock.ml index b097f33b..82e78be8 100644 --- a/lib/picos_std.structured/flock.ml +++ b/lib/picos_std.structured/flock.ml @@ -10,5 +10,5 @@ let error ?callstack exn_bt = Bundle.error (get ()) ?callstack exn_bt let fork_as_promise thunk = Bundle.fork_as_promise_pass (get ()) thunk FLS let fork action = Bundle.fork_pass (get ()) action FLS -let join_after ?callstack ?on_return fn = - Bundle.join_after_pass ?callstack ?on_return fn Bundle.FLS +let join_after ?callstack ?on_return ?on_terminate fn = + Bundle.join_after_pass ?callstack ?on_return ?on_terminate fn Bundle.FLS diff --git a/lib/picos_std.structured/for.ocaml4.ml b/lib/picos_std.structured/for.ocaml4.ml new file mode 100644 index 00000000..25d018ad --- /dev/null +++ b/lib/picos_std.structured/for.ocaml4.ml @@ -0,0 +1,57 @@ +type _ tdt = + | Empty : [> `Empty ] tdt + | Range : { + mutable lo : int; + hi : int; + parent : [ `Empty | `Range ] tdt; + } + -> [> `Range ] tdt + +let[@poll error] cas_lo (Range r : [ `Range ] tdt) before after = + r.lo == before + && begin + r.lo <- after; + true + end + +let rec for_out t (Range r as range : [ `Range ] tdt) action = + let lo_before = r.lo in + let n = r.hi - lo_before in + if 0 < n then begin + if Bundle.is_running t then begin + let lo_after = lo_before + 1 in + if cas_lo range lo_before lo_after then begin + try action lo_before + with exn -> Bundle.error t exn (Printexc.get_raw_backtrace ()) + end; + for_out t range action + end + end + else + match r.parent with + | Empty -> () + | Range _ as range -> for_out t range action + +let rec for_in t (Range r as range : [ `Range ] tdt) action = + let lo_before = r.lo in + let n = r.hi - lo_before in + if n <= 1 then for_out t range action + else + let lo_after = lo_before + (n asr 1) in + if cas_lo range lo_before lo_after then begin + Bundle.fork t (fun () -> for_in t range action); + let child = Range { lo = lo_before; hi = lo_after; parent = range } in + for_in t child action + end + else for_in t range action + +let for_n ?on_terminate n action = + if 0 < n then + if n = 1 then + try action 0 + with + | Control.Terminate when Bundle.on_terminate on_terminate == `Ignore -> + () + else + let range = Range { lo = 0; hi = n; parent = Empty } in + Bundle.join_after ?on_terminate @@ fun t -> for_in t range action diff --git a/lib/picos_std.structured/for.ocaml5.ml b/lib/picos_std.structured/for.ocaml5.ml new file mode 100644 index 00000000..97465b9a --- /dev/null +++ b/lib/picos_std.structured/for.ocaml5.ml @@ -0,0 +1,82 @@ +open Picos + +type per_fiber = { mutable lo : int; mutable hi : int } + +type _ tdt = + | Empty : [> `Empty ] tdt + | Range : { + mutable _lo : int; + hi : int; + parent : [ `Empty | `Range ] tdt; + } + -> [> `Range ] tdt + +external lo_as_atomic : [ `Range ] tdt -> int Atomic.t = "%identity" + +let rec for_out t (Range r as range : [ `Range ] tdt) per_fiber action = + let lo_before = Atomic.get (lo_as_atomic range) in + let n = r.hi - lo_before in + if 0 < n then begin + let lo_after = lo_before + 1 + (n asr 1) in + if Atomic.compare_and_set (lo_as_atomic range) lo_before lo_after then begin + per_fiber.lo <- lo_before; + per_fiber.hi <- lo_after; + while Bundle.is_running t && per_fiber.lo < per_fiber.hi do + try + while per_fiber.lo < per_fiber.hi do + let i = per_fiber.lo in + per_fiber.lo <- i + 1; + action i + done + with exn -> Bundle.error t exn (Printexc.get_raw_backtrace ()) + done + end; + for_out t range per_fiber action + end + else + match r.parent with + | Empty -> () + | Range _ as range -> for_out t range per_fiber action + +let rec for_in t (Range r as range : [ `Range ] tdt) per_fiber action = + let lo_before = Atomic.get (lo_as_atomic range) in + let n = r.hi - lo_before in + if n <= 1 then for_out t range per_fiber action + else + let lo_after = lo_before + (n asr 1) in + if Atomic.compare_and_set (lo_as_atomic range) lo_before lo_after then begin + Bundle.fork t (fun () -> for_in_enter t range action); + let child = Range { _lo = lo_before; hi = lo_after; parent = range } in + for_in t child per_fiber action + end + else for_in t range per_fiber action + +and for_in_enter bundle range action = + let per_fiber = { lo = 0; hi = 0 } in + let effc (type a) : + a Effect.t -> ((a, _) Effect.Deep.continuation -> _) option = function + | Fiber.Spawn _ | Fiber.Current | Computation.Cancel_after _ -> None + | _ -> + (* Might be blocking, so fork any remaining work to another fiber. *) + if per_fiber.lo < per_fiber.hi then begin + let range = + Range { _lo = per_fiber.lo; hi = per_fiber.hi; parent = Empty } + in + per_fiber.lo <- per_fiber.hi; + Bundle.fork bundle (fun () -> for_in_enter bundle range action) + end; + None + in + let handler = Effect.Deep.{ effc } in + Effect.Deep.try_with (for_in bundle range per_fiber) action handler + +let for_n ?on_terminate n action = + if 0 < n then + if n = 1 then + try action 0 + with + | Control.Terminate when Bundle.on_terminate on_terminate == `Ignore -> + () + else + let range = Range { _lo = 0; hi = n; parent = Empty } in + Bundle.join_after ?on_terminate @@ fun t -> for_in_enter t range action diff --git a/lib/picos_std.structured/picos_std_structured.mli b/lib/picos_std.structured/picos_std_structured.mli index 4f0daac8..23dec2c8 100644 --- a/lib/picos_std.structured/picos_std_structured.mli +++ b/lib/picos_std.structured/picos_std_structured.mli @@ -223,7 +223,11 @@ module Bundle : sig (** Represents a bundle of fibers. *) val join_after : - ?callstack:int -> ?on_return:[ `Terminate | `Wait ] -> (t -> 'a) -> 'a + ?callstack:int -> + ?on_return:[ `Terminate | `Wait ] -> + ?on_terminate:[ `Raise | `Ignore ] -> + (t -> 'a) -> + 'a (** [join_after scope] calls [scope] with a {{!t} bundle}. A call of [join_after] returns or raises only after [scope] has returned or raised and all {{!fork} forked} fibers have terminated. If [scope] raises an @@ -298,7 +302,11 @@ module Flock : sig *) val join_after : - ?callstack:int -> ?on_return:[ `Terminate | `Wait ] -> (unit -> 'a) -> 'a + ?callstack:int -> + ?on_return:[ `Terminate | `Wait ] -> + ?on_terminate:[ `Raise | `Ignore ] -> + (unit -> 'a) -> + 'a (** [join_after scope] creates a new flock for fibers, calls [scope] after setting current flock to the new flock, and restores the previous flock, if any after [scope] exits. The flock will be implicitly propagated to all @@ -356,28 +364,37 @@ module Flock : sig end module Run : sig - (** Operations for running fibers in specific patterns. *) - - val all : (unit -> unit) list -> unit - (** [all actions] starts the actions as separate fibers and waits until they - all return. If any of the actions raises an exception other than - {{!Control.Terminate} [Terminate]} the remaining fibers will be canceled - and the exception, if any, will be raised. - - ⚠️ One of the actions may be run on the current fiber. - - ⚠️ It is not guaranteed that any of the actions in the list are called. In - particular, after any action raises an exception other than - {{!Control.Terminate} [Terminate]} or after the main fiber is canceled, - the actions that have not yet started may be skipped entirely. - - [all] is roughly equivalent to + (** Operations for running actions concurrently. + + ⚠️ In general, when an action expected to return the unit value [()] + started by an operation in this module raises an unhandled exception, + other than {{!Control.Terminate} [Terminate]}, which is not counted as an + error, the whole operation will be canceled and the exception will be + raised. + + ⚠️ The operations in this module run their actions such that any action may + block to await without preventing other actions from being run. At the + limit every action may need to be run in a distinct fiber. However, it is + not guaranteed that every action always runs in a distinct fiber. The + actual number of fibers used can be much less than the number of actions + executed in case the actions do not block, complete quickly, and/or the + scheduler doesn't provide parallelism. + + ⚠️ The operations in this module do not guaranteed that any of the actions + are executed. In particular, after any action raises an unhandled + exception or after the main fiber is canceled, the actions that have not + yet started may be skipped entirely. *) + + val all : ?on_terminate:[ `Raise | `Ignore ] -> (unit -> unit) list -> unit + (** [all actions] starts the actions and waits until they all complete. + + [all] is roughly equivalent to: {[ - let all actions = - Bundle.join_after @@ fun bundle -> - List.iter (Bundle.fork bundle) actions - ]} - but treats the list of actions as a single computation. *) + let all ?on_terminate actions = + Bundle.join_after ?on_terminate @@ fun bundle -> + try actions |> List.iter @@ fun action -> Bundle.fork bundle action + with exn -> Bundle.error bundle exn (Printexc.get_raw_backtrace ()) + ]} *) val any : (unit -> unit) list -> unit (** [any actions] starts the actions as separate fibers and waits until one of @@ -385,14 +402,6 @@ module Run : sig {{!Control.Terminate} [Terminate]} after which the remaining started fibers will be canceled and the exception, if any, will be raised. - ⚠️ One of the actions may be run on the current fiber. - - ⚠️ It is not guaranteed that any of the actions in the list are called. In - particular, after the first action returns successfully or after any - action raises an exception other than {{!Control.Terminate} [Terminate]} - or after the main fiber is canceled, the actions that have not yet started - may be skipped entirely. - [any] is roughly equivalent to {[ let any actions = @@ -403,7 +412,7 @@ module Run : sig Bundle.fork bundle @@ fun () -> action (); Bundle.terminate bundle - with Control.Terminate -> () + with exn -> Bundle.error bundle exn (Printexc.get_raw_backtrace ()) ]} but treats the list of actions as a single computation. *) @@ -441,13 +450,47 @@ module Run : sig let value = action () in if Atomic.compare_and_set result None (Some value) then Bundle.terminate bundle - with Control.Terminate -> () + with exn -> + Bundle.error bundle exn (Printexc.get_raw_backtrace ()) end; match Atomic.get result with | None -> raise Control.Terminate | Some value -> value ]} but treats the list of actions as a single computation. *) + + val for_n : ?on_terminate:[ `Raise | `Ignore ] -> int -> (int -> unit) -> unit + (** [for_n n action], when [0 < n], starts [action i] for each integer [i] + from [0] to [n-1] and waits until they all complete. + + [for_n] is roughly equivalent to: + {[ + let for_n ?on_terminate n action = + Bundle.join_after ?on_terminate @@ fun bundle -> + for i = 0 to n - 1 do + Bundle.fork bundle @@ fun () -> action i + done + ]} *) + + val find_opt_n : int -> (int -> 'a option) -> 'a list + (** *) + + module Array : sig + (** Concurrent operations over arrays. *) + + type 'a t = 'a array + (** Type alias for [array]. *) + + val iter : + ?on_terminate:[ `Raise | `Ignore ] -> ('a -> unit) -> 'a t -> unit + (** [iter action array] starts [action array.(i)] for each index of the + [array] and waits until they all complete. *) + + val map : ('a -> 'b) -> 'a t -> 'b t + (** [map fn array] starts [fn array.(i)] for each index of the [array], + waits until they all complete, and return a new array with the return + values from those calls. *) + end end (** {1 Examples} diff --git a/lib/picos_std.structured/run.ml b/lib/picos_std.structured/run.ml index 853712f7..d7a0a8f9 100644 --- a/lib/picos_std.structured/run.ml +++ b/lib/picos_std.structured/run.ml @@ -37,15 +37,15 @@ let rec spawn (Bundle r as t : Bundle.t) state wrap = function Fiber.spawn fiber (wrap t state main); spawn t state wrap mains -let run actions state wrap = - Bundle.join_after @@ fun (Bundle _ as t : Bundle.t) -> +let run ?on_terminate actions state wrap = + Bundle.join_after ?on_terminate @@ fun (Bundle _ as t : Bundle.t) -> try spawn t state wrap actions with exn -> let bt = Printexc.get_raw_backtrace () in Bundle.decr t; Bundle.error t exn bt -let all actions = run actions () wrap_all +let all ?on_terminate actions = run ?on_terminate actions () wrap_all let any actions = run actions () wrap_any let first_or_terminate actions = @@ -54,3 +54,48 @@ let first_or_terminate actions = match Atomic.get result with | None -> raise Control.Terminate | Some value -> value + +(* *) + +let for_n = For.for_n + +let find_opt_n n fn = + let results = Atomic.make [] in + begin + match + For.for_n ~on_terminate:`Raise n @@ fun i -> + match fn i with + | None -> () + | Some v -> + let backoff = ref Backoff.default in + while + let before = Atomic.get results in + let after = v :: before in + not (Atomic.compare_and_set results before after) + do + backoff := Backoff.once !backoff + done; + raise_notrace Control.Terminate + with + | () -> () + | exception Control.Terminate -> () + end; + Atomic.get results + +module Array = struct + type 'a t = 'a array + + let iter ?on_terminate action xs = + for_n ?on_terminate (Array.length xs) @@ fun i -> + action (Array.unsafe_get xs i) + + let[@inline never] map fn xs = + let n = Array.length xs in + if n = 0 then [||] + else + let ys = Array.make n (Obj.magic ()) in + for_n ~on_terminate:`Raise n (fun i -> + Array.unsafe_set ys i (fn (Array.unsafe_get xs i))); + if Obj.double_tag != Obj.tag (Obj.repr (Array.unsafe_get ys 0)) then ys + else Array.map Fun.id ys +end diff --git a/test/dune b/test/dune index dd98e136..13fc892a 100644 --- a/test/dune +++ b/test/dune @@ -257,7 +257,10 @@ (modules test_structured) (libraries alcotest + backoff + multicore-magic picos + picos.domain picos_aux.mpscq picos_std.finally picos_std.structured diff --git a/test/test_structured.ml b/test/test_structured.ml index 27a2c17c..eaa51845 100644 --- a/test/test_structured.ml +++ b/test/test_structured.ml @@ -3,9 +3,10 @@ open Picos_std_finally open Picos_std_structured open Picos_std_sync module Mpscq = Picos_aux_mpscq +module Atomic_array = Multicore_magic.Atomic_array (** Helper to check that computation is restored *) -let check join_after ?callstack ?on_return scope = +let check join_after ?callstack ?on_return ?on_terminate scope = let open Picos in let fiber = Fiber.current () in let before = Fiber.get_computation fiber in @@ -14,7 +15,7 @@ let check join_after ?callstack ?on_return scope = assert (before == after) in lastly check_computation_was_scoped @@ fun () -> - join_after ?callstack ?on_return @@ fun bundle -> + join_after ?callstack ?on_return ?on_terminate @@ fun bundle -> let during = Fiber.get_computation fiber in assert (before != during); scope bundle @@ -219,7 +220,10 @@ let test_any_and_all_returns () = |> List.iter @@ fun n_terminates -> [ 0; 1; 2 ] |> List.iter @@ fun n_incr -> - [ (Run.all, n_incr, n_incr); (Run.any, Int.min 1 n_incr, n_incr) ] + [ + (Run.all ?on_terminate:None, n_incr, n_incr); + (Run.any, Int.min 1 n_incr, n_incr); + ] |> List.iter @@ fun (run_op, min, max) -> Test_scheduler.run ~max_domains:(n_terminates + n_incr + 1) @@ fun () -> @@ -258,6 +262,52 @@ let test_race_any () = (* This is non-deterministic and may need to changed if flaky *) assert (Atomic.get winner = 1) +let test_for_n_basic () = + Test_scheduler.run ~max_domains:(Picos_domain.recommended_domain_count ()) + @@ fun () -> + [ `Ignore; `Raise ] + |> List.iter @@ fun on_terminate -> + for n = 0 to 128 do + let elems = Atomic_array.make n 0 in + let incremented = Atomic.make 0 in + let terminated = Atomic.make 0 in + match + Run.for_n ~on_terminate n @@ fun i -> + if Random.bool () then Control.yield (); + if Random.int n = i then begin + Atomic.incr terminated; + raise Control.Terminate + end; + Atomic.incr incremented; + while + let before = Atomic_array.unsafe_fenceless_get elems i in + let after = before + 1 in + not (Atomic_array.unsafe_compare_and_set elems i before after) + do + Backoff.once Backoff.default |> ignore + done + with + | () -> + if on_terminate != `Ignore then begin + assert (0 = Atomic.get terminated); + assert (n = Atomic.get incremented) + end; + for i = 0 to n - 1 do + let n = Atomic_array.unsafe_fenceless_get elems i in + assert (0 <= n && n <= 1); + if n = 0 then Atomic.decr terminated else Atomic.decr incremented + done; + assert (0 = Atomic.get terminated); + assert (0 = Atomic.get incremented) + | exception Control.Terminate -> + assert (on_terminate == `Raise); + assert (1 <= Atomic.get terminated); + for i = 0 to n - 1 do + let n = Atomic_array.unsafe_fenceless_get elems i in + assert (0 <= n && n <= 1) + done + done + let () = [ ( "Bundle", @@ -286,6 +336,7 @@ let () = Alcotest.test_case "any and all errors" `Quick test_any_and_all_errors; Alcotest.test_case "any and all returns" `Quick test_any_and_all_returns; Alcotest.test_case "race any" `Quick test_race_any; + Alcotest.test_case "for_n basic" `Quick test_for_n_basic; ] ); ] |> Alcotest.run "Picos_structured"