Skip to content

Commit

Permalink
Define inlined, more-dumb versions of Mod_bounds ops (#3605)
Browse files Browse the repository at this point in the history
* Define inlined, more-dumb versions of Mod_bounds ops

The mod-bounds ops that used polymorphic functions and first-class modules were
not only a little bit over-abstracted for some peoples' taste, they actually had
noticeably worse performance, due to suboptimal inlining behavior (and
allocation!) of the `Accent_lattice` FCM. This changes them to be dumber and
more repetitive, but more direct, which also gets us a few percent performance
win.

* Delete Axis_collection entirely

Let's inhabit one side of this architectural fence entirely, rather than
straddling it.

* define equal functions for lattices

This sadly doesn't show up on a benchmark, but the profiler shows us spending a
fair bit less time here and the generated assembly looks way nicer. Plus it'll
probably vectorize much more nicely, once we have a vectorizer.
  • Loading branch information
glittershark authored Feb 26, 2025
1 parent 5d0cc9e commit 5f8a33b
Show file tree
Hide file tree
Showing 9 changed files with 440 additions and 538 deletions.
319 changes: 192 additions & 127 deletions typing/jkind.ml

Large diffs are not rendered by default.

234 changes: 0 additions & 234 deletions typing/jkind_axis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -172,240 +172,6 @@ module Axis = struct
| Nonmodal Nullability -> false
end

module Axis_collection = struct
module Indexed_gen (T : Misc.T2) = struct
type 'a t_poly =
{ locality : (Mode.Locality.Const.t, 'a) T.t;
linearity : (Mode.Linearity.Const.t, 'a) T.t;
uniqueness : (Mode.Uniqueness.Const.t, 'a) T.t;
portability : (Mode.Portability.Const.t, 'a) T.t;
contention : (Mode.Contention.Const.t, 'a) T.t;
yielding : (Mode.Yielding.Const.t, 'a) T.t;
externality : (Externality.t, 'a) T.t;
nullability : (Nullability.t, 'a) T.t
}

type 'a t = 'a t_poly

let get (type a) ~(axis : a Axis.t) (t : 'b t) : (a, 'b) T.t =
match axis with
| Modal (Comonadic Areality) -> t.locality
| Modal (Comonadic Linearity) -> t.linearity
| Modal (Monadic Uniqueness) -> t.uniqueness
| Modal (Comonadic Portability) -> t.portability
| Modal (Monadic Contention) -> t.contention
| Modal (Comonadic Yielding) -> t.yielding
| Nonmodal Externality -> t.externality
| Nonmodal Nullability -> t.nullability

let set (type a) ~(axis : a Axis.t) (t : 'b t) (value : (a, 'b) T.t) =
match axis with
| Modal (Comonadic Areality) -> { t with locality = value }
| Modal (Comonadic Linearity) -> { t with linearity = value }
| Modal (Monadic Uniqueness) -> { t with uniqueness = value }
| Modal (Comonadic Portability) -> { t with portability = value }
| Modal (Monadic Contention) -> { t with contention = value }
| Modal (Comonadic Yielding) -> { t with yielding = value }
| Nonmodal Externality -> { t with externality = value }
| Nonmodal Nullability -> { t with nullability = value }

(* Since we don't have polymorphic parameters, use a record to pass the
polymorphic function *)
module Create = struct
module Monadic (M : Misc.Stdlib.Monad.S) = struct
type 'a f = { f : 'axis. axis:'axis Axis.t -> ('axis, 'a) T.t M.t }
[@@unboxed]

let[@inline] f { f } =
let open M.Syntax in
let* locality = f ~axis:Axis.(Modal (Comonadic Areality)) in
let* uniqueness = f ~axis:Axis.(Modal (Monadic Uniqueness)) in
let* linearity = f ~axis:Axis.(Modal (Comonadic Linearity)) in
let* contention = f ~axis:Axis.(Modal (Monadic Contention)) in
let* portability = f ~axis:Axis.(Modal (Comonadic Portability)) in
let* yielding = f ~axis:Axis.(Modal (Comonadic Yielding)) in
let* externality = f ~axis:Axis.(Nonmodal Externality) in
let* nullability = f ~axis:Axis.(Nonmodal Nullability) in
M.return
{ locality;
uniqueness;
linearity;
contention;
portability;
yielding;
externality;
nullability
}
end
[@@inline]

module Monadic_identity = Monadic (Misc.Stdlib.Monad.Identity)

type 'a f = 'a Monadic_identity.f

let[@inline] f f = Monadic_identity.f f
end

module Map = struct
module Monadic (M : Misc.Stdlib.Monad.S) = struct
type ('a, 'b) f =
{ f :
'axis. axis:'axis Axis.t -> ('axis, 'a) T.t -> ('axis, 'b) T.t M.t
}
[@@unboxed]

module Create = Create.Monadic (M)

let[@inline] f { f } bounds =
Create.f { f = (fun ~axis -> f ~axis (get ~axis bounds)) }
end
[@@inline]

module Monadic_identity = Monadic (Misc.Stdlib.Monad.Identity)

type ('a, 'b) f = ('a, 'b) Monadic_identity.f

let[@inline] f f bounds = Monadic_identity.f f bounds
end

module Iter = struct
type 'a f = { f : 'axis. axis:'axis Axis.t -> ('axis, 'a) T.t -> unit }
[@@unboxed]

let[@inline] f { f }
{ locality;
linearity;
uniqueness;
portability;
contention;
yielding;
externality;
nullability
} =
f ~axis:Axis.(Modal (Comonadic Areality)) locality;
f ~axis:Axis.(Modal (Monadic Uniqueness)) uniqueness;
f ~axis:Axis.(Modal (Comonadic Linearity)) linearity;
f ~axis:Axis.(Modal (Monadic Contention)) contention;
f ~axis:Axis.(Modal (Comonadic Portability)) portability;
f ~axis:Axis.(Modal (Comonadic Yielding)) yielding;
f ~axis:Axis.(Nonmodal Externality) externality;
f ~axis:Axis.(Nonmodal Nullability) nullability
end

module Map2 = struct
module Monadic (M : Misc.Stdlib.Monad.S) = struct
type ('a, 'b, 'c) f =
{ f :
'axis.
axis:'axis Axis.t ->
('axis, 'a) T.t ->
('axis, 'b) T.t ->
('axis, 'c) T.t M.t
}
[@@unboxed]

module Create = Create.Monadic (M)

let[@inline] f { f } bounds1 bounds2 =
Create.f
{ f = (fun ~axis -> f ~axis (get ~axis bounds1) (get ~axis bounds2))
}
end
[@@inline]

module Monadic_identity = Monadic (Misc.Stdlib.Monad.Identity)

type ('a, 'b, 'c) f = ('a, 'b, 'c) Monadic_identity.f

let[@inline] f f bounds1 bounds2 = Monadic_identity.f f bounds1 bounds2
end

module Fold = struct
type ('a, 'r) f =
{ f : 'axis. axis:'axis Axis.t -> ('axis, 'a) T.t -> 'r }
[@@unboxed]

let[@inline] f { f }
{ locality;
linearity;
uniqueness;
portability;
contention;
yielding;
externality;
nullability
} ~combine =
combine (f ~axis:Axis.(Modal (Comonadic Areality)) locality)
@@ combine (f ~axis:Axis.(Modal (Monadic Uniqueness)) uniqueness)
@@ combine (f ~axis:Axis.(Modal (Comonadic Linearity)) linearity)
@@ combine (f ~axis:Axis.(Modal (Monadic Contention)) contention)
@@ combine (f ~axis:Axis.(Modal (Comonadic Portability)) portability)
@@ combine (f ~axis:Axis.(Modal (Comonadic Yielding)) yielding)
@@ combine (f ~axis:Axis.(Nonmodal Externality) externality)
@@ f ~axis:Axis.(Nonmodal Nullability) nullability
end

module Fold2 = struct
type ('a, 'b, 'r) f =
{ f :
'axis. axis:'axis Axis.t -> ('axis, 'a) T.t -> ('axis, 'b) T.t -> 'r
}
[@@unboxed]

let[@inline] f { f }
{ locality = loc1;
linearity = lin1;
uniqueness = uni1;
portability = por1;
contention = con1;
yielding = yie1;
externality = ext1;
nullability = nul1
}
{ locality = loc2;
linearity = lin2;
uniqueness = uni2;
portability = por2;
contention = con2;
yielding = yie2;
externality = ext2;
nullability = nul2
} ~combine =
combine (f ~axis:Axis.(Modal (Comonadic Areality)) loc1 loc2)
@@ combine (f ~axis:Axis.(Modal (Monadic Uniqueness)) uni1 uni2)
@@ combine (f ~axis:Axis.(Modal (Comonadic Linearity)) lin1 lin2)
@@ combine (f ~axis:Axis.(Modal (Monadic Contention)) con1 con2)
@@ combine (f ~axis:Axis.(Modal (Comonadic Portability)) por1 por2)
@@ combine (f ~axis:Axis.(Modal (Comonadic Yielding)) yie1 yie2)
@@ combine (f ~axis:Axis.(Nonmodal Externality) ext1 ext2)
@@ f ~axis:Axis.(Nonmodal Nullability) nul1 nul2
end
end

module Indexed (T : Misc.T1) = struct
include Indexed_gen (struct
type ('a, 'b) t = 'a T.t
end)

type nonrec t = unit t
end

module Identity = Indexed (Misc.Stdlib.Monad.Identity)

include Indexed_gen (struct
type ('a, 'b) t = 'b
end)

let create ~f = Create.f { f = (fun ~axis -> f ~axis:(Axis.Pack axis)) }

let map ~f t = Map.f { f = (fun ~axis:_ x -> f x) } t

let mapi ~f t = Map.f { f = (fun ~axis x -> f ~axis:(Axis.Pack axis) x) } t

let fold ~f ~combine t =
Fold.f { f = (fun ~axis acc -> f ~axis:(Axis.Pack axis) acc) } t ~combine
end

module Axis_set = struct
(* This could be [bool Axis_collection.t], but instead we represent it as a bitfield for
performance (this matters, since these are hammered on quite a bit during with-bound
Expand Down
142 changes: 0 additions & 142 deletions typing/jkind_axis.mli
Original file line number Diff line number Diff line change
Expand Up @@ -68,148 +68,6 @@ module Axis : sig
val is_modal : _ t -> bool
end

(** A collection with one item for each jkind axis *)
module Axis_collection : sig
module type S_gen := sig
type ('a, 'b) u

(* This is t_poly instead of t because in some instantiations of this signature, u
ignores its second parameter. In order to avoid needed to apply a useless type
parameter for those instantiations, we define [type t = unit t_poly] in them. In
instantiations where the polymorphism is actually used, we define
[type 'a t = 'a t_poly] *)
type 'a t_poly =
{ locality : (Mode.Locality.Const.t, 'a) u;
linearity : (Mode.Linearity.Const.t, 'a) u;
uniqueness : (Mode.Uniqueness.Const.t, 'a) u;
portability : (Mode.Portability.Const.t, 'a) u;
contention : (Mode.Contention.Const.t, 'a) u;
yielding : (Mode.Yielding.Const.t, 'a) u;
externality : (Externality.t, 'a) u;
nullability : (Nullability.t, 'a) u
}

val get : axis:'a Axis.t -> 'b t_poly -> ('a, 'b) u

val set : axis:'a Axis.t -> 'b t_poly -> ('a, 'b) u -> 'b t_poly

(** Create an axis collection by applying the function on each axis *)
module Create : sig
module Monadic (M : Misc.Stdlib.Monad.S) : sig
type 'a f = { f : 'axis. axis:'axis Axis.t -> ('axis, 'a) u M.t }
[@@unboxed]

val f : 'a f -> 'a t_poly M.t
end

(** This record type is used to pass a polymorphic function to [create] *)
type 'a f = 'a Monadic(Misc.Stdlib.Monad.Identity).f

val f : 'a f -> 'a t_poly
end

(** Map an operation over all the bounds *)
module Map : sig
module Monadic (M : Misc.Stdlib.Monad.S) : sig
type ('a, 'b) f =
{ f : 'axis. axis:'axis Axis.t -> ('axis, 'a) u -> ('axis, 'b) u M.t }
[@@unboxed]

val f : ('a, 'b) f -> 'a t_poly -> 'b t_poly M.t
end

type ('a, 'b) f = ('a, 'b) Monadic(Misc.Stdlib.Monad.Identity).f

val f : ('a, 'b) f -> 'a t_poly -> 'b t_poly
end

module Iter : sig
type 'a f = { f : 'axis. axis:'axis Axis.t -> ('axis, 'a) u -> unit }
[@@unboxed]

val f : 'a f -> 'a t_poly -> unit
end

(** Map an operation over two sets of bounds *)
module Map2 : sig
module Monadic (M : Misc.Stdlib.Monad.S) : sig
type ('a, 'b, 'c) f =
{ f :
'axis.
axis:'axis Axis.t ->
('axis, 'a) u ->
('axis, 'b) u ->
('axis, 'c) u M.t
}
[@@unboxed]

val f : ('a, 'b, 'c) f -> 'a t_poly -> 'b t_poly -> 'c t_poly M.t
end

type ('a, 'b, 'c) f = ('a, 'b, 'c) Monadic(Misc.Stdlib.Monad.Identity).f

val f : ('a, 'b, 'c) f -> 'a t_poly -> 'b t_poly -> 'c t_poly
end

(** Fold an operation over the bounds to a summary value *)
module Fold : sig
type ('a, 'r) f = { f : 'axis. axis:'axis Axis.t -> ('axis, 'a) u -> 'r }
[@@unboxed]

(** [combine] should be commutative and associative. *)
val f : ('a, 'r) f -> 'a t_poly -> combine:('r -> 'r -> 'r) -> 'r
end

(** Fold an operation over two sets of bounds to a summary value *)
module Fold2 : sig
type ('a, 'b, 'r) f =
{ f : 'axis. axis:'axis Axis.t -> ('axis, 'a) u -> ('axis, 'b) u -> 'r }
[@@unboxed]

(** [combine] should be commutative and associative. *)
val f :
('a, 'b, 'r) f ->
'a t_poly ->
'b t_poly ->
combine:('r -> 'r -> 'r) ->
'r
end
end

module type S_poly := sig
include S_gen

type 'a t = 'a t_poly
end

module type S_mono := sig
include S_gen

type t = unit t_poly
end

module Indexed_gen (T : Misc.T2) : S_poly with type ('a, 'b) u := ('a, 'b) T.t

module Indexed (T : Misc.T1) : S_mono with type ('a, 'b) u := 'a T.t

module Identity : S_mono with type ('a, 'b) u := 'a

include S_poly with type ('a, 'b) u := 'b

val create : f:(axis:Axis.packed -> 'a) -> 'a t

val get : axis:'ax Axis.t -> 'a t -> 'a

val set : axis:'ax Axis.t -> 'a t -> 'a -> 'a t

val mapi : f:(axis:Axis.packed -> 'a -> 'a) -> 'a t -> 'a t

val map : f:('a -> 'a) -> 'a t -> 'a t

val fold :
f:(axis:Axis.packed -> 'a -> 'r) -> combine:('r -> 'r -> 'r) -> 'a t -> 'r
end

module Axis_set : sig
(** A set of [Axis.t], represented as a bitfield for efficiency. *)
type t [@@immediate]
Expand Down
Loading

0 comments on commit 5f8a33b

Please sign in to comment.