From 5f8a33b965a6c728d13ac4876663785a89f73064 Mon Sep 17 00:00:00 2001 From: Aspen Smith Date: Wed, 26 Feb 2025 14:04:31 -0500 Subject: [PATCH] Define inlined, more-dumb versions of Mod_bounds ops (#3605) * Define inlined, more-dumb versions of Mod_bounds ops The mod-bounds ops that used polymorphic functions and first-class modules were not only a little bit over-abstracted for some peoples' taste, they actually had noticeably worse performance, due to suboptimal inlining behavior (and allocation!) of the `Accent_lattice` FCM. This changes them to be dumber and more repetitive, but more direct, which also gets us a few percent performance win. * Delete Axis_collection entirely Let's inhabit one side of this architectural fence entirely, rather than straddling it. * define equal functions for lattices This sadly doesn't show up on a benchmark, but the profiler shows us spending a fair bit less time here and the generated assembly looks way nicer. Plus it'll probably vectorize much more nicely, once we have a vectorizer. --- typing/jkind.ml | 319 +++++++++++++++++++++++++----------------- typing/jkind_axis.ml | 234 ------------------------------- typing/jkind_axis.mli | 142 ------------------- typing/mode.ml | 99 +++++++++++-- typing/mode_intf.mli | 4 + typing/typemode.ml | 70 ++++++--- typing/typemode.mli | 22 ++- typing/types.ml | 53 ++++++- typing/types.mli | 35 ++++- 9 files changed, 440 insertions(+), 538 deletions(-) diff --git a/typing/jkind.ml b/typing/jkind.ml index c175c51c18d..01222d63e8f 100644 --- a/typing/jkind.ml +++ b/typing/jkind.ml @@ -48,13 +48,13 @@ module Sub_result = struct | Less | Not_le of Sub_failure_reason.t Nonempty_list.t - let of_le_result ~failure_reason (le_result : Misc.Le_result.t) = + let[@inline] of_le_result ~failure_reason (le_result : Misc.Le_result.t) = match le_result with | Less -> Less | Equal -> Equal | Not_le -> Not_le (failure_reason ()) - let combine sr1 sr2 = + let[@inline] combine sr1 sr2 = match sr1, sr2 with | Equal, Equal -> Equal | Equal, Less | Less, Equal | Less, Less -> Less @@ -383,101 +383,152 @@ let raise ~loc err = raise (Error.User_error (loc, err)) module Mod_bounds = struct include Types.Jkind_mod_bounds - - let debug_print ppf - { locality; - linearity; - uniqueness; - portability; - contention; - yielding; - externality; - nullability - } = + module Locality = Mode.Locality.Const + module Linearity = Mode.Linearity.Const + module Uniqueness = Mode.Uniqueness.Const_op + module Portability = Mode.Portability.Const + module Contention = Mode.Contention.Const_op + module Yielding = Mode.Yielding.Const + + let debug_print ppf t = Format.fprintf ppf "@[{ locality = %a;@ linearity = %a;@ uniqueness = %a;@ portability = \ %a;@ contention = %a;@ yielding = %a;@ externality = %a;@ nullability = \ %a }@]" - Mode.Locality.Const.print locality Mode.Linearity.Const.print linearity - Mode.Uniqueness.Const.print uniqueness Mode.Portability.Const.print - portability Mode.Contention.Const.print contention - Mode.Yielding.Const.print yielding Externality.print externality - Nullability.print nullability + Locality.print (locality t) Linearity.print (linearity t) Uniqueness.print + (uniqueness t) Portability.print (portability t) Contention.print + (contention t) Yielding.print (yielding t) Externality.print + (externality t) Nullability.print (nullability t) let min = - Create.f - { f = - (fun (type axis) ~(axis : axis Axis.t) -> - let (module Bound_ops) = Axis.get axis in - Bound_ops.min) - } + create ~locality:Locality.min ~linearity:Linearity.min + ~uniqueness:Uniqueness.min ~portability:Portability.min + ~contention:Contention.min ~yielding:Yielding.min + ~externality:Externality.min ~nullability:Nullability.min let max = - Create.f - { f = - (fun (type axis) ~(axis : axis Axis.t) -> - let (module Bound_ops) = Axis.get axis in - Bound_ops.max) - } - - let simple ~locality ~linearity ~uniqueness ~portability ~contention ~yielding - ~externality ~nullability = - { locality; - linearity; - uniqueness; - portability; - contention; - yielding; - externality; - nullability - } - - let join = - Map2.f - { f = - (fun (type axis) ~(axis : axis Axis.t) -> - let (module Bound_ops) = Axis.get axis in - Bound_ops.join) - } - - let meet = - Map2.f - { f = - (fun (type axis) ~(axis : axis Axis.t) -> - let (module Bound_ops) = Axis.get axis in - Bound_ops.meet) - } - - let less_or_equal = - Fold2.f - { f = - (fun (type axis) ~(axis : axis Axis.t) b1 b2 -> - let (module Bound_ops) = Axis.get axis in - Sub_result.of_le_result (Bound_ops.less_or_equal b1 b2) - ~failure_reason:(fun () -> [Axis_disagreement (Pack axis)])) - } - ~combine:Sub_result.combine - - let equal = - Fold2.f - { f = - (fun (type axis) ~(axis : axis Axis.t) -> - let (module Bound_ops) = Axis.get axis in - Bound_ops.equal) - } - ~combine:( && ) + create ~locality:Locality.max ~linearity:Linearity.max + ~uniqueness:Uniqueness.max ~portability:Portability.max + ~contention:Contention.max ~yielding:Yielding.max + ~externality:Externality.max ~nullability:Nullability.max + + let join t1 t2 = + let locality = Locality.join (locality t1) (locality t2) in + let linearity = Linearity.join (linearity t1) (linearity t2) in + let uniqueness = Uniqueness.join (uniqueness t1) (uniqueness t2) in + let portability = Portability.join (portability t1) (portability t2) in + let contention = Contention.join (contention t1) (contention t2) in + let yielding = Yielding.join (yielding t1) (yielding t2) in + let externality = Externality.join (externality t1) (externality t2) in + let nullability = Nullability.join (nullability t1) (nullability t2) in + create ~locality ~linearity ~uniqueness ~portability ~contention ~yielding + ~externality ~nullability + + let meet t1 t2 = + let locality = Locality.meet (locality t1) (locality t2) in + let linearity = Linearity.meet (linearity t1) (linearity t2) in + let uniqueness = Uniqueness.meet (uniqueness t1) (uniqueness t2) in + let portability = Portability.meet (portability t1) (portability t2) in + let contention = Contention.meet (contention t1) (contention t2) in + let yielding = Yielding.meet (yielding t1) (yielding t2) in + let externality = Externality.meet (externality t1) (externality t2) in + let nullability = Nullability.meet (nullability t1) (nullability t2) in + create ~locality ~linearity ~uniqueness ~portability ~contention ~yielding + ~externality ~nullability + + let less_or_equal t1 t2 = + let[@inline] axis_less_or_equal ~le ~axis a b : Sub_result.t = + match le a b, le b a with + | true, true -> Equal + | true, false -> Less + | false, _ -> Not_le [Axis_disagreement axis] + in + Sub_result.combine + (axis_less_or_equal ~le:Locality.le + ~axis:(Pack (Modal (Comonadic Areality))) (locality t1) (locality t2)) + @@ Sub_result.combine + (axis_less_or_equal ~le:Uniqueness.le + ~axis:(Pack (Modal (Monadic Uniqueness))) (uniqueness t1) + (uniqueness t2)) + @@ Sub_result.combine + (axis_less_or_equal ~le:Linearity.le + ~axis:(Pack (Modal (Comonadic Linearity))) (linearity t1) + (linearity t2)) + @@ Sub_result.combine + (axis_less_or_equal ~le:Contention.le + ~axis:(Pack (Modal (Monadic Contention))) (contention t1) + (contention t2)) + @@ Sub_result.combine + (axis_less_or_equal ~le:Portability.le + ~axis:(Pack (Modal (Comonadic Portability))) (portability t1) + (portability t2)) + @@ Sub_result.combine + (axis_less_or_equal ~le:Yielding.le + ~axis:(Pack (Modal (Comonadic Yielding))) (yielding t1) + (yielding t2)) + @@ Sub_result.combine + (axis_less_or_equal ~le:Externality.le + ~axis:(Pack (Nonmodal Externality)) (externality t1) + (externality t2)) + @@ axis_less_or_equal ~le:Nullability.le ~axis:(Pack (Nonmodal Nullability)) + (nullability t1) (nullability t2) + + let equal t1 t2 = + Locality.equal (locality t1) (locality t2) + && Linearity.equal (linearity t1) (linearity t2) + && Uniqueness.equal (uniqueness t1) (uniqueness t2) + && Portability.equal (portability t1) (portability t2) + && Contention.equal (contention t1) (contention t2) + && Yielding.equal (yielding t1) (yielding t2) + && Externality.equal (externality t1) (externality t2) + && Nullability.equal (nullability t1) (nullability t2) + + let[@inline] get (type a) ~(axis : a Axis.t) t : a = + match axis with + | Modal (Monadic Uniqueness) -> uniqueness t + | Modal (Comonadic Areality) -> locality t + | Modal (Monadic Contention) -> contention t + | Modal (Comonadic Linearity) -> linearity t + | Modal (Comonadic Portability) -> portability t + | Modal (Comonadic Yielding) -> yielding t + | Nonmodal Externality -> externality t + | Nonmodal Nullability -> nullability t (** Get all axes that are set to max *) let get_max_axes t = - Axis_set.create ~f:(fun ~axis:(Pack axis) -> - let (module Axis_ops) = Axis.get axis in - let bound = get ~axis t in - Axis_ops.le Axis_ops.max bound) + let[@inline] add_if b ax axis_set = + if b then Axis_set.add axis_set ax else axis_set + in + Axis_set.empty + |> add_if + (Locality.le Locality.max (locality t)) + (Modal (Comonadic Areality)) + |> add_if + (Linearity.le Linearity.max (linearity t)) + (Modal (Comonadic Linearity)) + |> add_if + (Uniqueness.le Uniqueness.max (uniqueness t)) + (Modal (Monadic Uniqueness)) + |> add_if + (Portability.le Portability.max (portability t)) + (Modal (Comonadic Portability)) + |> add_if + (Contention.le Contention.max (contention t)) + (Modal (Monadic Contention)) + |> add_if + (Yielding.le Yielding.max (yielding t)) + (Modal (Comonadic Yielding)) + |> add_if + (Externality.le Externality.max (externality t)) + (Nonmodal Externality) + |> add_if + (Nullability.le Nullability.max (nullability t)) + (Nonmodal Nullability) let for_arrow = - simple ~linearity:Linearity.Const.max ~locality:Locality.Const.max - ~uniqueness:Uniqueness.Const_op.min ~portability:Portability.Const.max - ~contention:Contention.Const_op.min ~yielding:Yielding.Const.max + create ~linearity:Linearity.max ~locality:Locality.max + ~uniqueness:Uniqueness.min ~portability:Portability.max + ~contention:Contention.min ~yielding:Yielding.max ~externality:Externality.max ~nullability:Nullability.Non_null end @@ -828,9 +879,7 @@ module Layout_and_axes = struct (type_expr * With_bounds_type_info.t) list -> Mod_bounds.t * (l2 * r2) with_bounds * Fuel_status.t = function (* early cutoff *) - | _ - when Sub_result.is_le - (Mod_bounds.less_or_equal Mod_bounds.max bounds_so_far) -> + | _ when Mod_bounds.equal Mod_bounds.max bounds_so_far -> (* CR layouts v2.8: we can do better by early-terminating on a per-axis basis *) bounds_so_far, No_with_bounds, Sufficient_fuel | [] -> bounds_so_far, No_with_bounds, ctl.fuel_status @@ -856,16 +905,24 @@ module Layout_and_axes = struct loop ctl bounds_so_far bs | false -> ( let join_bounds b1 b2 ~relevant_axes = - Mod_bounds.Map2.f - { f = - (fun (type a) ~(axis : a Axis.t) b1 b2 -> - if Axis_set.mem relevant_axes axis - then - let (module Bound_ops) = Axis.get axis in - Bound_ops.join b1 b2 - else b1) - } - b1 b2 + let value_for_axis (type a) ~(axis : a Axis.t) : a = + if Axis_set.mem relevant_axes axis + then + let (module Bound_ops) = Axis.get axis in + Bound_ops.join (Mod_bounds.get ~axis b1) + (Mod_bounds.get ~axis b2) + else Mod_bounds.get ~axis b1 + in + Mod_bounds.create + ~locality:(value_for_axis ~axis:(Modal (Comonadic Areality))) + ~linearity:(value_for_axis ~axis:(Modal (Comonadic Linearity))) + ~uniqueness:(value_for_axis ~axis:(Modal (Monadic Uniqueness))) + ~portability: + (value_for_axis ~axis:(Modal (Comonadic Portability))) + ~contention:(value_for_axis ~axis:(Modal (Monadic Contention))) + ~yielding:(value_for_axis ~axis:(Modal (Comonadic Yielding))) + ~externality:(value_for_axis ~axis:(Nonmodal Externality)) + ~nullability:(value_for_axis ~axis:(Nonmodal Nullability)) in let found_jkind_for_ty new_ctl b_upper_bounds b_with_bounds quality : Mod_bounds.t * (l2 * r2) with_bounds * Fuel_status.t = @@ -1079,9 +1136,10 @@ module Const = struct let mk_jkind ~mode_crossing ~nullability (layout : Layout.Const.t) = let mod_bounds = - match mode_crossing with - | true -> { Mod_bounds.min with nullability } - | false -> { Mod_bounds.max with nullability } + (match mode_crossing with + | true -> Mod_bounds.min + | false -> Mod_bounds.max) + |> Mod_bounds.set_nullability nullability in { layout; mod_bounds; with_bounds = No_with_bounds } @@ -1110,7 +1168,7 @@ module Const = struct { jkind = { layout = Base Value; mod_bounds = - Mod_bounds.simple ~locality:Locality.Const.max + Mod_bounds.create ~locality:Locality.Const.max ~linearity:Linearity.Const.min ~portability:Portability.Const.min ~yielding:Yielding.Const.min ~uniqueness:Uniqueness.Const_op.max @@ -1125,7 +1183,7 @@ module Const = struct { jkind = { layout = Base Value; mod_bounds = - Mod_bounds.simple ~locality:Locality.Const.max + Mod_bounds.create ~locality:Locality.Const.max ~linearity:Linearity.Const.min ~portability:Portability.Const.min ~yielding:Yielding.Const.min ~contention:Contention.Const_op.max @@ -1181,9 +1239,8 @@ module Const = struct { jkind = { immediate.jkind with mod_bounds = - { immediate.jkind.mod_bounds with - externality = Externality.External64 - } + Mod_bounds.set_externality Externality.External64 + immediate.jkind.mod_bounds }; name = "immediate64" } @@ -1403,9 +1460,8 @@ module Const = struct { jkind = { layout = jkind.layout; mod_bounds = - { Mod_bounds.max with - nullability = Nullability.Non_null - }; + Mod_bounds.set_nullability Nullability.Non_null + Mod_bounds.max; with_bounds = No_with_bounds }; name = Layout.Const.to_string jkind.layout @@ -1504,19 +1560,26 @@ module Const = struct (* for each mode, lower the corresponding modal bound to be that mode *) let parsed_modifiers = Typemode.transl_modifier_annots modifiers in let mod_bounds = - Mod_bounds.Create.f - { f = - (fun (type a) ~(axis : a Axis.t) -> - let (module A) = Axis.get axis in - let parsed_modifier = - Typemode.Transled_modifiers.get ~axis parsed_modifiers - in - let base_bound = Mod_bounds.get ~axis base.mod_bounds in - match parsed_modifier, base_bound with - | None, base_bound -> base_bound - | Some parsed_modifier, base_modifier -> - A.meet base_modifier parsed_modifier.txt) - } + let value_for_axis (type a) ~(axis : a Axis.t) : a = + let (module A) = Axis.get axis in + let parsed_modifier = + Typemode.Transled_modifiers.get ~axis parsed_modifiers + in + let base_bound = Mod_bounds.get ~axis base.mod_bounds in + match parsed_modifier, base_bound with + | None, base_modifier -> base_modifier + | Some parsed_modifier, base_modifier -> + A.meet base_modifier parsed_modifier.txt + in + Mod_bounds.create + ~locality:(value_for_axis ~axis:(Modal (Comonadic Areality))) + ~linearity:(value_for_axis ~axis:(Modal (Comonadic Linearity))) + ~uniqueness:(value_for_axis ~axis:(Modal (Monadic Uniqueness))) + ~portability:(value_for_axis ~axis:(Modal (Comonadic Portability))) + ~contention:(value_for_axis ~axis:(Modal (Monadic Contention))) + ~yielding:(value_for_axis ~axis:(Modal (Comonadic Yielding))) + ~externality:(value_for_axis ~axis:(Nonmodal Externality)) + ~nullability:(value_for_axis ~axis:(Nonmodal Nullability)) in { layout = base.layout; mod_bounds; with_bounds = No_with_bounds } | Product ts -> @@ -1551,7 +1614,7 @@ module Const = struct let get_required_layouts_level (_context : 'd Context_with_transl.t) (jkind : 'd t) = let rec scan_layout (l : Layout.Const.t) : Language_extension.maturity = - match l, jkind.mod_bounds.nullability with + match l, Mod_bounds.nullability jkind.mod_bounds with | (Base (Float64 | Float32 | Word | Bits32 | Bits64 | Vec128) | Any), _ | Base Value, Non_null | Base Value, Maybe_null -> @@ -1605,7 +1668,9 @@ module Jkind_desc = struct let of_const t = Layout_and_axes.map Layout.of_const t let add_nullability_crossing t = - { t with mod_bounds = { t.mod_bounds with nullability = Nullability.min } } + { t with + mod_bounds = Mod_bounds.set_nullability Nullability.min t.mod_bounds + } let unsafely_set_mod_bounds t ~from = { t with mod_bounds = from.mod_bounds; with_bounds = No_with_bounds } @@ -1668,7 +1733,7 @@ module Jkind_desc = struct let layout, sort = Layout.of_new_sort_var () in ( { layout; mod_bounds = - { Mod_bounds.max with nullability = nullability_upper_bound }; + Mod_bounds.set_nullability nullability_upper_bound Mod_bounds.max; with_bounds = No_with_bounds }, sort ) @@ -2062,7 +2127,7 @@ let for_object = fresh_jkind { layout = Sort (Base Value); mod_bounds = - Mod_bounds.simple ~linearity ~locality ~uniqueness ~portability + Mod_bounds.create ~linearity ~locality ~uniqueness ~portability ~contention ~yielding ~externality:Externality.max ~nullability:Non_null; with_bounds = No_with_bounds @@ -2174,7 +2239,7 @@ let set_externality_upper_bound jk externality_upper_bound = jkind = { jk.jkind with mod_bounds = - { jk.jkind.mod_bounds with externality = externality_upper_bound } + Mod_bounds.set_externality externality_upper_bound jk.jkind.mod_bounds } } @@ -3028,7 +3093,7 @@ let is_value_for_printing ~ignore_null { jkind; _ } = then { value with mod_bounds = - { value.mod_bounds with nullability = Nullability.Maybe_null } + Mod_bounds.set_nullability Nullability.Maybe_null value.mod_bounds } :: values else values diff --git a/typing/jkind_axis.ml b/typing/jkind_axis.ml index 5d8b96ee9e9..674a35ffaea 100644 --- a/typing/jkind_axis.ml +++ b/typing/jkind_axis.ml @@ -172,240 +172,6 @@ module Axis = struct | Nonmodal Nullability -> false end -module Axis_collection = struct - module Indexed_gen (T : Misc.T2) = struct - type 'a t_poly = - { locality : (Mode.Locality.Const.t, 'a) T.t; - linearity : (Mode.Linearity.Const.t, 'a) T.t; - uniqueness : (Mode.Uniqueness.Const.t, 'a) T.t; - portability : (Mode.Portability.Const.t, 'a) T.t; - contention : (Mode.Contention.Const.t, 'a) T.t; - yielding : (Mode.Yielding.Const.t, 'a) T.t; - externality : (Externality.t, 'a) T.t; - nullability : (Nullability.t, 'a) T.t - } - - type 'a t = 'a t_poly - - let get (type a) ~(axis : a Axis.t) (t : 'b t) : (a, 'b) T.t = - match axis with - | Modal (Comonadic Areality) -> t.locality - | Modal (Comonadic Linearity) -> t.linearity - | Modal (Monadic Uniqueness) -> t.uniqueness - | Modal (Comonadic Portability) -> t.portability - | Modal (Monadic Contention) -> t.contention - | Modal (Comonadic Yielding) -> t.yielding - | Nonmodal Externality -> t.externality - | Nonmodal Nullability -> t.nullability - - let set (type a) ~(axis : a Axis.t) (t : 'b t) (value : (a, 'b) T.t) = - match axis with - | Modal (Comonadic Areality) -> { t with locality = value } - | Modal (Comonadic Linearity) -> { t with linearity = value } - | Modal (Monadic Uniqueness) -> { t with uniqueness = value } - | Modal (Comonadic Portability) -> { t with portability = value } - | Modal (Monadic Contention) -> { t with contention = value } - | Modal (Comonadic Yielding) -> { t with yielding = value } - | Nonmodal Externality -> { t with externality = value } - | Nonmodal Nullability -> { t with nullability = value } - - (* Since we don't have polymorphic parameters, use a record to pass the - polymorphic function *) - module Create = struct - module Monadic (M : Misc.Stdlib.Monad.S) = struct - type 'a f = { f : 'axis. axis:'axis Axis.t -> ('axis, 'a) T.t M.t } - [@@unboxed] - - let[@inline] f { f } = - let open M.Syntax in - let* locality = f ~axis:Axis.(Modal (Comonadic Areality)) in - let* uniqueness = f ~axis:Axis.(Modal (Monadic Uniqueness)) in - let* linearity = f ~axis:Axis.(Modal (Comonadic Linearity)) in - let* contention = f ~axis:Axis.(Modal (Monadic Contention)) in - let* portability = f ~axis:Axis.(Modal (Comonadic Portability)) in - let* yielding = f ~axis:Axis.(Modal (Comonadic Yielding)) in - let* externality = f ~axis:Axis.(Nonmodal Externality) in - let* nullability = f ~axis:Axis.(Nonmodal Nullability) in - M.return - { locality; - uniqueness; - linearity; - contention; - portability; - yielding; - externality; - nullability - } - end - [@@inline] - - module Monadic_identity = Monadic (Misc.Stdlib.Monad.Identity) - - type 'a f = 'a Monadic_identity.f - - let[@inline] f f = Monadic_identity.f f - end - - module Map = struct - module Monadic (M : Misc.Stdlib.Monad.S) = struct - type ('a, 'b) f = - { f : - 'axis. axis:'axis Axis.t -> ('axis, 'a) T.t -> ('axis, 'b) T.t M.t - } - [@@unboxed] - - module Create = Create.Monadic (M) - - let[@inline] f { f } bounds = - Create.f { f = (fun ~axis -> f ~axis (get ~axis bounds)) } - end - [@@inline] - - module Monadic_identity = Monadic (Misc.Stdlib.Monad.Identity) - - type ('a, 'b) f = ('a, 'b) Monadic_identity.f - - let[@inline] f f bounds = Monadic_identity.f f bounds - end - - module Iter = struct - type 'a f = { f : 'axis. axis:'axis Axis.t -> ('axis, 'a) T.t -> unit } - [@@unboxed] - - let[@inline] f { f } - { locality; - linearity; - uniqueness; - portability; - contention; - yielding; - externality; - nullability - } = - f ~axis:Axis.(Modal (Comonadic Areality)) locality; - f ~axis:Axis.(Modal (Monadic Uniqueness)) uniqueness; - f ~axis:Axis.(Modal (Comonadic Linearity)) linearity; - f ~axis:Axis.(Modal (Monadic Contention)) contention; - f ~axis:Axis.(Modal (Comonadic Portability)) portability; - f ~axis:Axis.(Modal (Comonadic Yielding)) yielding; - f ~axis:Axis.(Nonmodal Externality) externality; - f ~axis:Axis.(Nonmodal Nullability) nullability - end - - module Map2 = struct - module Monadic (M : Misc.Stdlib.Monad.S) = struct - type ('a, 'b, 'c) f = - { f : - 'axis. - axis:'axis Axis.t -> - ('axis, 'a) T.t -> - ('axis, 'b) T.t -> - ('axis, 'c) T.t M.t - } - [@@unboxed] - - module Create = Create.Monadic (M) - - let[@inline] f { f } bounds1 bounds2 = - Create.f - { f = (fun ~axis -> f ~axis (get ~axis bounds1) (get ~axis bounds2)) - } - end - [@@inline] - - module Monadic_identity = Monadic (Misc.Stdlib.Monad.Identity) - - type ('a, 'b, 'c) f = ('a, 'b, 'c) Monadic_identity.f - - let[@inline] f f bounds1 bounds2 = Monadic_identity.f f bounds1 bounds2 - end - - module Fold = struct - type ('a, 'r) f = - { f : 'axis. axis:'axis Axis.t -> ('axis, 'a) T.t -> 'r } - [@@unboxed] - - let[@inline] f { f } - { locality; - linearity; - uniqueness; - portability; - contention; - yielding; - externality; - nullability - } ~combine = - combine (f ~axis:Axis.(Modal (Comonadic Areality)) locality) - @@ combine (f ~axis:Axis.(Modal (Monadic Uniqueness)) uniqueness) - @@ combine (f ~axis:Axis.(Modal (Comonadic Linearity)) linearity) - @@ combine (f ~axis:Axis.(Modal (Monadic Contention)) contention) - @@ combine (f ~axis:Axis.(Modal (Comonadic Portability)) portability) - @@ combine (f ~axis:Axis.(Modal (Comonadic Yielding)) yielding) - @@ combine (f ~axis:Axis.(Nonmodal Externality) externality) - @@ f ~axis:Axis.(Nonmodal Nullability) nullability - end - - module Fold2 = struct - type ('a, 'b, 'r) f = - { f : - 'axis. axis:'axis Axis.t -> ('axis, 'a) T.t -> ('axis, 'b) T.t -> 'r - } - [@@unboxed] - - let[@inline] f { f } - { locality = loc1; - linearity = lin1; - uniqueness = uni1; - portability = por1; - contention = con1; - yielding = yie1; - externality = ext1; - nullability = nul1 - } - { locality = loc2; - linearity = lin2; - uniqueness = uni2; - portability = por2; - contention = con2; - yielding = yie2; - externality = ext2; - nullability = nul2 - } ~combine = - combine (f ~axis:Axis.(Modal (Comonadic Areality)) loc1 loc2) - @@ combine (f ~axis:Axis.(Modal (Monadic Uniqueness)) uni1 uni2) - @@ combine (f ~axis:Axis.(Modal (Comonadic Linearity)) lin1 lin2) - @@ combine (f ~axis:Axis.(Modal (Monadic Contention)) con1 con2) - @@ combine (f ~axis:Axis.(Modal (Comonadic Portability)) por1 por2) - @@ combine (f ~axis:Axis.(Modal (Comonadic Yielding)) yie1 yie2) - @@ combine (f ~axis:Axis.(Nonmodal Externality) ext1 ext2) - @@ f ~axis:Axis.(Nonmodal Nullability) nul1 nul2 - end - end - - module Indexed (T : Misc.T1) = struct - include Indexed_gen (struct - type ('a, 'b) t = 'a T.t - end) - - type nonrec t = unit t - end - - module Identity = Indexed (Misc.Stdlib.Monad.Identity) - - include Indexed_gen (struct - type ('a, 'b) t = 'b - end) - - let create ~f = Create.f { f = (fun ~axis -> f ~axis:(Axis.Pack axis)) } - - let map ~f t = Map.f { f = (fun ~axis:_ x -> f x) } t - - let mapi ~f t = Map.f { f = (fun ~axis x -> f ~axis:(Axis.Pack axis) x) } t - - let fold ~f ~combine t = - Fold.f { f = (fun ~axis acc -> f ~axis:(Axis.Pack axis) acc) } t ~combine -end - module Axis_set = struct (* This could be [bool Axis_collection.t], but instead we represent it as a bitfield for performance (this matters, since these are hammered on quite a bit during with-bound diff --git a/typing/jkind_axis.mli b/typing/jkind_axis.mli index 9d5df621872..8999effa4c5 100644 --- a/typing/jkind_axis.mli +++ b/typing/jkind_axis.mli @@ -68,148 +68,6 @@ module Axis : sig val is_modal : _ t -> bool end -(** A collection with one item for each jkind axis *) -module Axis_collection : sig - module type S_gen := sig - type ('a, 'b) u - - (* This is t_poly instead of t because in some instantiations of this signature, u - ignores its second parameter. In order to avoid needed to apply a useless type - parameter for those instantiations, we define [type t = unit t_poly] in them. In - instantiations where the polymorphism is actually used, we define - [type 'a t = 'a t_poly] *) - type 'a t_poly = - { locality : (Mode.Locality.Const.t, 'a) u; - linearity : (Mode.Linearity.Const.t, 'a) u; - uniqueness : (Mode.Uniqueness.Const.t, 'a) u; - portability : (Mode.Portability.Const.t, 'a) u; - contention : (Mode.Contention.Const.t, 'a) u; - yielding : (Mode.Yielding.Const.t, 'a) u; - externality : (Externality.t, 'a) u; - nullability : (Nullability.t, 'a) u - } - - val get : axis:'a Axis.t -> 'b t_poly -> ('a, 'b) u - - val set : axis:'a Axis.t -> 'b t_poly -> ('a, 'b) u -> 'b t_poly - - (** Create an axis collection by applying the function on each axis *) - module Create : sig - module Monadic (M : Misc.Stdlib.Monad.S) : sig - type 'a f = { f : 'axis. axis:'axis Axis.t -> ('axis, 'a) u M.t } - [@@unboxed] - - val f : 'a f -> 'a t_poly M.t - end - - (** This record type is used to pass a polymorphic function to [create] *) - type 'a f = 'a Monadic(Misc.Stdlib.Monad.Identity).f - - val f : 'a f -> 'a t_poly - end - - (** Map an operation over all the bounds *) - module Map : sig - module Monadic (M : Misc.Stdlib.Monad.S) : sig - type ('a, 'b) f = - { f : 'axis. axis:'axis Axis.t -> ('axis, 'a) u -> ('axis, 'b) u M.t } - [@@unboxed] - - val f : ('a, 'b) f -> 'a t_poly -> 'b t_poly M.t - end - - type ('a, 'b) f = ('a, 'b) Monadic(Misc.Stdlib.Monad.Identity).f - - val f : ('a, 'b) f -> 'a t_poly -> 'b t_poly - end - - module Iter : sig - type 'a f = { f : 'axis. axis:'axis Axis.t -> ('axis, 'a) u -> unit } - [@@unboxed] - - val f : 'a f -> 'a t_poly -> unit - end - - (** Map an operation over two sets of bounds *) - module Map2 : sig - module Monadic (M : Misc.Stdlib.Monad.S) : sig - type ('a, 'b, 'c) f = - { f : - 'axis. - axis:'axis Axis.t -> - ('axis, 'a) u -> - ('axis, 'b) u -> - ('axis, 'c) u M.t - } - [@@unboxed] - - val f : ('a, 'b, 'c) f -> 'a t_poly -> 'b t_poly -> 'c t_poly M.t - end - - type ('a, 'b, 'c) f = ('a, 'b, 'c) Monadic(Misc.Stdlib.Monad.Identity).f - - val f : ('a, 'b, 'c) f -> 'a t_poly -> 'b t_poly -> 'c t_poly - end - - (** Fold an operation over the bounds to a summary value *) - module Fold : sig - type ('a, 'r) f = { f : 'axis. axis:'axis Axis.t -> ('axis, 'a) u -> 'r } - [@@unboxed] - - (** [combine] should be commutative and associative. *) - val f : ('a, 'r) f -> 'a t_poly -> combine:('r -> 'r -> 'r) -> 'r - end - - (** Fold an operation over two sets of bounds to a summary value *) - module Fold2 : sig - type ('a, 'b, 'r) f = - { f : 'axis. axis:'axis Axis.t -> ('axis, 'a) u -> ('axis, 'b) u -> 'r } - [@@unboxed] - - (** [combine] should be commutative and associative. *) - val f : - ('a, 'b, 'r) f -> - 'a t_poly -> - 'b t_poly -> - combine:('r -> 'r -> 'r) -> - 'r - end - end - - module type S_poly := sig - include S_gen - - type 'a t = 'a t_poly - end - - module type S_mono := sig - include S_gen - - type t = unit t_poly - end - - module Indexed_gen (T : Misc.T2) : S_poly with type ('a, 'b) u := ('a, 'b) T.t - - module Indexed (T : Misc.T1) : S_mono with type ('a, 'b) u := 'a T.t - - module Identity : S_mono with type ('a, 'b) u := 'a - - include S_poly with type ('a, 'b) u := 'b - - val create : f:(axis:Axis.packed -> 'a) -> 'a t - - val get : axis:'ax Axis.t -> 'a t -> 'a - - val set : axis:'ax Axis.t -> 'a t -> 'a -> 'a t - - val mapi : f:(axis:Axis.packed -> 'a -> 'a) -> 'a t -> 'a t - - val map : f:('a -> 'a) -> 'a t -> 'a t - - val fold : - f:(axis:Axis.packed -> 'a -> 'r) -> combine:('r -> 'r -> 'r) -> 'a t -> 'r -end - module Axis_set : sig (** A set of [Axis.t], represented as a bitfield for efficiency. *) type t [@@immediate] diff --git a/typing/mode.ml b/typing/mode.ml index 24be03f6264..a85d7e24e71 100644 --- a/typing/mode.ml +++ b/typing/mode.ml @@ -51,7 +51,9 @@ module Lattices = struct let legacy = L.legacy - let le a b = L.le b a + let[@inline] le a b = L.le b a + + let equal = L.equal let join = L.meet @@ -111,9 +113,14 @@ module Lattices = struct let legacy = Global - let le a b = + let[@inline] le a b = match a, b with Global, _ | _, Local -> true | Local, Global -> false + let[@inline] equal a b = + match a, b with + | Global, Global | Local, Local -> true + | Global, Local | Local, Global -> false + let join a b = match a, b with | Local, _ | _, Local -> Local @@ -147,6 +154,16 @@ module Lattices = struct let legacy = Global + let[@inline] equal a b = + match a, b with + | Global, Global -> true + | Regional, Regional -> true + | Local, Local -> true + | Global, (Regional | Local) + | Regional, (Global | Local) + | Local, (Global | Regional) -> + false + let join a b = match a, b with | Local, _ | _, Local -> Local @@ -159,7 +176,7 @@ module Lattices = struct | Regional, _ | _, Regional -> Regional | Local, Local -> Local - let le a b = + let[@inline] le a b = match a, b with | Global, _ | _, Local -> true | _, Global | Local, _ -> false @@ -188,11 +205,17 @@ module Lattices = struct let legacy = Aliased - let le a b = + let[@inline] le a b = match a, b with | Unique, _ | _, Aliased -> true | Aliased, Unique -> false + let[@inline] equal a b = + match a, b with + | Unique, Unique -> true + | Aliased, Aliased -> true + | Unique, Aliased | Aliased, Unique -> false + let join a b = match a, b with | Aliased, _ | _, Aliased -> Aliased @@ -225,9 +248,15 @@ module Lattices = struct let legacy = Many - let le a b = + let[@inline] le a b = match a, b with Many, _ | _, Once -> true | Once, Many -> false + let[@inline] equal a b = + match a, b with + | Many, Many -> true + | Once, Once -> true + | Many, Once | Once, Many -> false + let join a b = match a, b with Once, _ | _, Once -> Once | Many, Many -> Many @@ -254,11 +283,17 @@ module Lattices = struct let legacy = Nonportable - let le a b = + let[@inline] le a b = match a, b with | Portable, _ | _, Nonportable -> true | Nonportable, Portable -> false + let[@inline] equal a b = + match a, b with + | Portable, Portable -> true + | Nonportable, Nonportable -> true + | Portable, Nonportable | Nonportable, Portable -> false + let join a b = match a, b with | Nonportable, _ | _, Nonportable -> Nonportable @@ -290,12 +325,22 @@ module Lattices = struct let legacy = Uncontended - let le a b = + let[@inline] le a b = match a, b with | Uncontended, _ | _, Contended -> true | _, Uncontended | Contended, _ -> false | Shared, Shared -> true + let[@inline] equal a b = + match a, b with + | Contended, Contended -> true + | Shared, Shared -> true + | Uncontended, Uncontended -> true + | Contended, (Shared | Uncontended) + | Shared, (Contended | Uncontended) + | Uncontended, (Contended | Shared) -> + false + let join a b = match a, b with | Contended, _ | _, Contended -> Contended @@ -331,11 +376,17 @@ module Lattices = struct let legacy = Unyielding - let le a b = + let[@inline] le a b = match a, b with | Unyielding, _ | _, Yielding -> true | Yielding, Unyielding -> false + let[@inline] equal a b = + match a, b with + | Yielding, Yielding -> true + | Unyielding, Unyielding -> true + | Yielding, Unyielding | Unyielding, Yielding -> false + let join a b = match a, b with | Yielding, _ | _, Yielding -> Yielding @@ -381,6 +432,12 @@ module Lattices = struct Uniqueness.le uniqueness1 uniqueness2 && Contention.le contention1 contention2 + let equal m1 m2 = + let { uniqueness = uniqueness1; contention = contention1 } = m1 in + let { uniqueness = uniqueness2; contention = contention2 } = m2 in + Uniqueness.equal uniqueness1 uniqueness2 + && Contention.equal contention1 contention2 + let join m1 m2 = let uniqueness = Uniqueness.join m1.uniqueness m2.uniqueness in let contention = Contention.join m1.contention m2.contention in @@ -457,6 +514,26 @@ module Lattices = struct && Portability.le portability1 portability2 && Yielding.le yielding1 yielding2 + let equal m1 m2 = + let { areality = areality1; + linearity = linearity1; + portability = portability1; + yielding = yielding1 + } = + m1 + in + let { areality = areality2; + linearity = linearity2; + portability = portability2; + yielding = yielding2 + } = + m2 + in + Areality.equal areality1 areality2 + && Linearity.equal linearity1 linearity2 + && Portability.equal portability1 portability2 + && Yielding.equal yielding1 yielding2 + let join m1 m2 = let areality = Areality.join m1.areality m2.areality in let linearity = Linearity.join m1.linearity m2.linearity in @@ -2011,6 +2088,12 @@ module Value_with (Areality : Areality) = struct let m1 = split m1 in Comonadic.le m0.comonadic m1.comonadic && Monadic.le m0.monadic m1.monadic + let equal m0 m1 = + let m0 = split m0 in + let m1 = split m1 in + Comonadic.equal m0.comonadic m1.comonadic + && Monadic.equal m0.monadic m1.monadic + let print ppf m = let { monadic; comonadic } = split m in Format.fprintf ppf "%a,%a" Comonadic.print comonadic Monadic.print monadic diff --git a/typing/mode_intf.mli b/typing/mode_intf.mli index 34f5445c7af..67c15457928 100644 --- a/typing/mode_intf.mli +++ b/typing/mode_intf.mli @@ -28,6 +28,10 @@ module type Lattice = sig val le : t -> t -> bool + (** [equal a b] is equivalent to [le a b && le b a], but defined separately for + performance reasons *) + val equal : t -> t -> bool + val join : t -> t -> t val meet : t -> t -> t diff --git a/typing/typemode.ml b/typing/typemode.ml index b8ec368407d..617ea9cb6bd 100644 --- a/typing/typemode.ml +++ b/typing/typemode.ml @@ -84,14 +84,52 @@ let transl_annot (type m) ~(annot_type : m annot_type) ~required_mode_maturity let unpack_mode_annot { txt = Parsetree.Mode s; loc } = { txt = s; loc } -module Transled_modifier = struct - type 'a t = 'a Location.loc option +module Transled_modifiers = struct + type t = + { locality : Mode.Locality.Const.t Location.loc option; + linearity : Mode.Linearity.Const.t Location.loc option; + uniqueness : Mode.Uniqueness.Const.t Location.loc option; + portability : Mode.Portability.Const.t Location.loc option; + contention : Mode.Contention.Const.t Location.loc option; + yielding : Mode.Yielding.Const.t Location.loc option; + externality : Jkind_axis.Externality.t Location.loc option; + nullability : Jkind_axis.Nullability.t Location.loc option + } - let drop_loc modifier = Option.map Location.get_txt modifier -end + let empty = + { locality = None; + linearity = None; + uniqueness = None; + portability = None; + contention = None; + yielding = None; + externality = None; + nullability = None + } + + let get (type a) ~(axis : a Axis.t) (t : t) : a Location.loc option = + match axis with + | Modal (Comonadic Areality) -> t.locality + | Modal (Comonadic Linearity) -> t.linearity + | Modal (Monadic Uniqueness) -> t.uniqueness + | Modal (Comonadic Portability) -> t.portability + | Modal (Monadic Contention) -> t.contention + | Modal (Comonadic Yielding) -> t.yielding + | Nonmodal Externality -> t.externality + | Nonmodal Nullability -> t.nullability -module Transled_modifiers = - Jkind_axis.Axis_collection.Indexed (Transled_modifier) + let set (type a) ~(axis : a Axis.t) (t : t) (value : a Location.loc option) : + t = + match axis with + | Modal (Comonadic Areality) -> { t with locality = value } + | Modal (Comonadic Linearity) -> { t with linearity = value } + | Modal (Monadic Uniqueness) -> { t with uniqueness = value } + | Modal (Comonadic Portability) -> { t with portability = value } + | Modal (Monadic Contention) -> { t with contention = value } + | Modal (Comonadic Yielding) -> { t with yielding = value } + | Nonmodal Externality -> { t with externality = value } + | Nonmodal Nullability -> { t with nullability = value } +end let transl_modifier_annots annots = let step modifiers_so_far annot = @@ -114,9 +152,7 @@ let transl_modifier_annots annots = if is_dup then raise (Error (annot.loc, Duplicated_axis axis)); Transled_modifiers.set ~axis modifiers_so_far (Some { txt = mode; loc }) in - let empty_modifiers = - Transled_modifiers.Create.f { f = (fun ~axis:_ -> None) } - in + let empty_modifiers = Transled_modifiers.empty in List.fold_left step empty_modifiers annots let transl_mode_annots annots : Alloc.Const.Option.t = @@ -134,16 +170,14 @@ let transl_mode_annots annots : Alloc.Const.Option.t = then raise (Error (annot.loc, Duplicated_axis axis)); Transled_modifiers.set ~axis modifiers_so_far (Some { txt = mode; loc }) in - let empty_modifiers = - Transled_modifiers.Create.f { f = (fun ~axis:_ -> None) } - in + let empty_modifiers = Transled_modifiers.empty in let modes = List.fold_left step empty_modifiers annots in - { areality = Transled_modifier.drop_loc modes.locality; - linearity = Transled_modifier.drop_loc modes.linearity; - uniqueness = Transled_modifier.drop_loc modes.uniqueness; - portability = Transled_modifier.drop_loc modes.portability; - contention = Transled_modifier.drop_loc modes.contention; - yielding = Transled_modifier.drop_loc modes.yielding + { areality = Option.map get_txt modes.locality; + linearity = Option.map get_txt modes.linearity; + uniqueness = Option.map get_txt modes.uniqueness; + portability = Option.map get_txt modes.portability; + contention = Option.map get_txt modes.contention; + yielding = Option.map get_txt modes.yielding } let untransl_mode_annots ~loc (modes : Mode.Alloc.Const.Option.t) = diff --git a/typing/typemode.mli b/typing/typemode.mli index eb8b3cc8cac..baa85bbeee0 100644 --- a/typing/typemode.mli +++ b/typing/typemode.mli @@ -30,13 +30,25 @@ val untransl_modalities : Mode.Modality.Value.Const.t -> Parsetree.modalities -module Transled_modifier : sig - type 'a t = 'a Location.loc option +module Transled_modifiers : sig + type t = + { locality : Mode.Locality.Const.t Location.loc option; + linearity : Mode.Linearity.Const.t Location.loc option; + uniqueness : Mode.Uniqueness.Const.t Location.loc option; + portability : Mode.Portability.Const.t Location.loc option; + contention : Mode.Contention.Const.t Location.loc option; + yielding : Mode.Yielding.Const.t Location.loc option; + externality : Jkind_axis.Externality.t Location.loc option; + nullability : Jkind_axis.Nullability.t Location.loc option + } + + val empty : t + + val get : axis:'a Jkind_axis.Axis.t -> t -> 'a Location.loc option + + val set : axis:'a Jkind_axis.Axis.t -> t -> 'a Location.loc option -> t end -module Transled_modifiers : - module type of Jkind_axis.Axis_collection.Indexed (Transled_modifier) - (** Interpret a list of modifiers. A "modifier" is any keyword coming after a `mod` in a jkind *) val transl_modifier_annots : Parsetree.modes -> Transled_modifiers.t diff --git a/typing/types.ml b/typing/types.ml index 751c7d203b1..60e4a061de6 100644 --- a/typing/types.ml +++ b/typing/types.ml @@ -28,8 +28,57 @@ let is_mutable = function (* Type expressions for the core language *) -module Jkind_mod_bounds = - Jkind_axis.Axis_collection.Indexed (Misc.Stdlib.Monad.Identity) +module Jkind_mod_bounds = struct + type t = { + locality: Mode.Locality.Const.t; + linearity: Mode.Linearity.Const.t; + uniqueness: Mode.Uniqueness.Const.t; + portability: Mode.Portability.Const.t; + contention: Mode.Contention.Const.t; + yielding: Mode.Yielding.Const.t; + externality: Jkind_axis.Externality.t; + nullability: Jkind_axis.Nullability.t; + } + + let[@inline] locality t = t.locality + let[@inline] linearity t = t.linearity + let[@inline] uniqueness t = t.uniqueness + let[@inline] portability t = t.portability + let[@inline] contention t = t.contention + let[@inline] yielding t = t.yielding + let[@inline] externality t = t.externality + let[@inline] nullability t = t.nullability + + let[@inline] create + ~locality + ~linearity + ~uniqueness + ~portability + ~contention + ~yielding + ~externality + ~nullability = + { + locality; + linearity; + uniqueness; + portability; + contention; + yielding; + externality; + nullability; + } + + let[@inline] set_locality locality t = { t with locality } + let[@inline] set_linearity linearity t = { t with linearity } + let[@inline] set_uniqueness uniqueness t = { t with uniqueness } + let[@inline] set_portability portability t = { t with portability } + let[@inline] set_contention contention t = { t with contention } + let[@inline] set_yielding yielding t = { t with yielding } + let[@inline] set_externality externality t = { t with externality } + let[@inline] set_nullability nullability t = { t with nullability } +end + module With_bounds_type_info = struct type t = {relevant_axes : Jkind_axis.Axis_set.t } [@@unboxed] diff --git a/typing/types.mli b/typing/types.mli index 50d1fcd6396..c12c067813e 100644 --- a/typing/types.mli +++ b/typing/types.mli @@ -69,8 +69,39 @@ val is_mutable : mutability -> bool *) (** The mod-bounds of a jkind *) -module Jkind_mod_bounds : - module type of Jkind_axis.Axis_collection.Indexed (Misc.Stdlib.Monad.Identity) +module Jkind_mod_bounds : sig + type t + + val create : + locality:Mode.Locality.Const.t -> + linearity:Mode.Linearity.Const.t -> + uniqueness:Mode.Uniqueness.Const.t -> + portability:Mode.Portability.Const.t -> + contention:Mode.Contention.Const.t -> + yielding:Mode.Yielding.Const.t -> + externality:Jkind_axis.Externality.t -> + nullability:Jkind_axis.Nullability.t -> + t + + val locality : t -> Mode.Locality.Const.t + val linearity : t -> Mode.Linearity.Const.t + val uniqueness : t -> Mode.Uniqueness.Const.t + val portability : t -> Mode.Portability.Const.t + val contention : t -> Mode.Contention.Const.t + val yielding : t -> Mode.Yielding.Const.t + val externality : t -> Jkind_axis.Externality.t + val nullability : t -> Jkind_axis.Nullability.t + + val set_locality : Mode.Locality.Const.t -> t -> t + val set_linearity : Mode.Linearity.Const.t -> t -> t + val set_uniqueness : Mode.Uniqueness.Const.t -> t -> t + val set_portability : Mode.Portability.Const.t -> t -> t + val set_contention : Mode.Contention.Const.t -> t -> t + val set_yielding : Mode.Yielding.Const.t -> t -> t + val set_externality : Jkind_axis.Externality.t -> t -> t + val set_nullability : Jkind_axis.Nullability.t -> t -> t +end + (** Information tracked about an individual type within the with-bounds for a jkind *) module With_bounds_type_info : sig