Skip to content

Commit

Permalink
Fix DCE/Subst01 to work under lambdas
Browse files Browse the repository at this point in the history
The naive encoding of PHOAS passes that need to produce both
[expr]-like output and data-like output simultaneously involves
exponential blowup.

This commit adds caching of results (and/or intermediates) of a
data-producing PHOAS pass in a tree structure that mimics the
PHOAS expression so that a subsequent pass can consume this tree
and a PHOAS expression to produce a new expression.

More concretely, suppose we are trying to write a pass that is
`expr var1 * expr var2 -> A * expr var3`. We can define an
`expr`-like-tree-structure that (a) doesn't use higher-order
things for `Abs` nodes, and (b) stores `A` at every node. Then we
can write a pass that is `expr var1 * expr var2 -> A * tree-of-A`
and then `expr var1 * expr var2 * tree-of-A -> expr var3` such
that we incur only linear overhead.

See also mit-plv#1761 and
mit-plv#1604 (comment).

Fixes mit-plv#1604
  • Loading branch information
JasonGross committed Jan 22, 2024
1 parent 16b3666 commit 7df5f3c
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 104 deletions.
10 changes: 1 addition & 9 deletions src/BoundsPipeline.v
Original file line number Diff line number Diff line change
Expand Up @@ -660,12 +660,6 @@ Module Pipeline.
(E : Expr t)
: DebugM (Expr t)
:= (E <- DoRewrite E;
(* Note that DCE evaluates the expr with two different [var]
arguments, and so results in a pipeline that is 2x slower
unless we pass through a uniformly concrete [var] type
first *)
dlet_nd e := ToFlat E in
let E := FromFlat e in
E <- if with_subst01 return DebugM (Expr t)
then wrap_debug_rewrite ("subst01 for " ++ descr) (Subst01.Subst01 ident.is_comment) E
else if with_dead_code_elimination return DebugM (Expr t)
Expand All @@ -675,8 +669,6 @@ Module Pipeline.
then wrap_debug_rewrite ("LetBindReturn for " ++ descr) (UnderLets.LetBindReturn (@ident.is_var_like)) E
else Debug.ret E;
E <- DoRewrite E; (* after inlining, see if any new rewrite redexes are available *)
dlet_nd e := ToFlat E in
let E := FromFlat e in
E <- if with_dead_code_elimination
then wrap_debug_rewrite ("DCE after " ++ descr) (DeadCodeElimination.EliminateDead ident.is_comment) E
else Debug.ret E;
Expand Down Expand Up @@ -1150,7 +1142,7 @@ Module Pipeline.
first [ progress destruct_head'_and
| progress cbv [Classes.base Classes.ident Classes.ident_interp Classes.base_interp Classes.exprInfo] in *
| progress intros
| progress rewrite_strat repeat topdown hints interp
| progress rewrite_strat repeat topdown choice (hints interp_extra) (hints interp)
| solve [ typeclasses eauto with nocore interp_extra wf_extra ]
| solve [ typeclasses eauto ]
| break_innermost_match_step
Expand Down
70 changes: 70 additions & 0 deletions src/Language/TreeCaching.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
(** * Tree Caching for PHOAS Expressions *)
(** The naive encoding of PHOAS passes that need to produce both
[expr]-like output and data-like output simultaneously involves
exponential blowup.
This file allows caching of results (and/or intermediates) of a
data-producing PHOAS pass in a tree structure that mimics the
PHOAS expression so that a subsequent pass can consume this tree
and a PHOAS expression to produce a new expression.
More concretely, suppose we are trying to write a pass that is
[expr var1 * expr var2 -> A * expr var3]. We can define an
[expr]-like-tree-structure that (a) doesn't use higher-order
things for [Abs] nodes, and (b) stores [A] at every node. Then we
can write a pass that is [expr var1 * expr var2 -> A * tree-of-A]
and then [expr var1 * expr var2 * tree-of-A -> expr var3] such
that we incur only linear overhead.
See also
%\href{https://github.com/mit-plv/fiat-crypto/issues/1604#issuecomment-1553341559}{mit-plv/fiat-crypto\#1604 with option (2)}%
#<a href=https://github.com/mit-plv/fiat-crypto/issues/1604##issuecomment-1553341559">mit-plv/fiat-crypto##1604 with option (2)</a>#
and
%\href{https://github.com/mit-plv/fiat-crypto/issues/1761}{mit-plv/fiat-crypto\#1761}%
#<a href=https://github.com/mit-plv/fiat-crypto/issues/1761#">mit-plv/fiat-crypto##1761</a>#. *)

Require Import Rewriter.Language.Language.

Module Compilers.
Export Language.Compilers.
Local Set Boolean Equality Schemes.
Local Set Decidable Equality Schemes.

Module tree_nd.
Section with_result.
Context {base_type : Type}.
Local Notation type := (type base_type).
Context {ident : type -> Type}
{result : Type}.
Local Notation expr := (@expr.expr base_type ident).

Inductive tree : Type :=
| Ident (r : result) : tree
| Var (r : result) : tree
| Abs (r : result) (f : option tree) : tree
| App (r : result) (f : option tree) (x : option tree) : tree
| LetIn (r : result) (x : option tree) (f : option tree) : tree
.
End with_result.
Global Arguments tree result : clear implicits, assert.
End tree_nd.

Module tree.
Section with_result.
Context {base_type : Type}.
Local Notation type := (type base_type).
Context {ident : type -> Type}
{result : type -> Type}.
Local Notation expr := (@expr.expr base_type ident).

Inductive tree : type -> Type :=
| Ident {t} (r : result t) : tree t
| Var {t} (r : result t) : tree t
| Abs {s d} (r : result (s -> d)) (f : option (tree d)) : tree (s -> d)
| App {s d} (r : result d) (f : option (tree (s -> d))) (x : option (tree s)) : tree d
| LetIn {A B} (r : result B) (x : option (tree A)) (f : option (tree B)) : tree B
.
End with_result.
Global Arguments tree {base_type} {result} t, {base_type} result t : assert.
End tree.
End Compilers.
147 changes: 92 additions & 55 deletions src/MiscCompilerPasses.v
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ Require Import Coq.MSets.MSetPositive.
Require Import Coq.FSets.FMapPositive.
Require Import Crypto.Util.ListUtil Coq.Lists.List.
Require Import Rewriter.Language.Language.
Require Import Crypto.Util.LetIn.
Require Import Crypto.Util.Notations.
Require Import Crypto.Language.TreeCaching.
Import ListNotations. Local Open Scope Z_scope.

Module Compilers.
Export Language.Compilers.
Export Language.TreeCaching.Compilers.
Import invert_expr.

Module Subst01.
Expand All @@ -33,6 +36,8 @@ Module Compilers.
(** some identifiers, like [comment], might always be live *)
(is_ident_always_live : forall t, ident t -> bool).
Local Notation expr := (@expr.expr base_type ident).
(* [option t] is "is the let-in here live?", meaningless elsewhere; the thunk is for debugging *)
Local Notation tree := (@tree_nd.tree (option t * (unit -> positive * list (positive * t)))).
(** N.B. This does not work well when let-binders are not at top-level *)
Fixpoint contains_always_live_ident {var} (dummy : forall t, var t) {t} (e : @expr var t)
: bool
Expand All @@ -46,28 +51,39 @@ Module Compilers.
| expr.LetIn tx tC ex eC
=> contains_always_live_ident dummy ex || contains_always_live_ident dummy (eC (dummy _))
end%bool.
Definition meaningless : option t * (unit -> positive * list (positive * t)) := (None, (fun 'tt => (1%positive, []%list))).
Global Opaque meaningless.
Fixpoint compute_live_counts' {t} (e : @expr (fun _ => positive) t) (cur_idx : positive) (live : PositiveMap.t _)
: positive * PositiveMap.t _
: positive * PositiveMap.t _ * option tree
:= match e with
| expr.Var t v => (cur_idx, PositiveMap_incr v live)
| expr.Ident t idc => (cur_idx, live)
| expr.Var t v
=> let '(idx, live) := (cur_idx, PositiveMap_incr v live) in
(idx, live, Some (tree_nd.Var meaningless))
| expr.Ident t idc
=> let '(idx, live) := (cur_idx, live) in
(idx, live, Some (tree_nd.Ident meaningless))
| expr.App s d f x
=> let '(idx, live) := @compute_live_counts' _ f cur_idx live in
let '(idx, live) := @compute_live_counts' _ x idx live in
(idx, live)
=> let '(idx, live, f_tree) := @compute_live_counts' _ f cur_idx live in
let '(idx, live, x_tree) := @compute_live_counts' _ x idx live in
(idx, live, Some (tree_nd.App meaningless f_tree x_tree))
| expr.Abs s d f
=> let '(idx, live) := @compute_live_counts' _ (f cur_idx) (Pos.succ cur_idx) live in
(cur_idx, live)
=> let '(idx, live, f_tree) := @compute_live_counts' _ (f cur_idx) (Pos.succ cur_idx) live in
(idx, live, Some (tree_nd.Abs meaningless f_tree))
| expr.LetIn tx tC ex eC
=> let '(idx, live) := @compute_live_counts' tC (eC cur_idx) (Pos.succ cur_idx) live in
=> let '(idx, live, C_tree) := @compute_live_counts' tC (eC cur_idx) (Pos.succ cur_idx) live in
let live := if contains_always_live_ident (fun _ => cur_idx (* dummy *)) ex
then PositiveMap_incr_always_live cur_idx live
else live in
if PositiveMap.mem cur_idx live
then @compute_live_counts' tx ex idx live
else (idx, live)
let debug_info := fun 'tt => (Pos.succ cur_idx, PositiveMap.elements live) in
match PositiveMap.find cur_idx live with
| Some x_count
=> let '(x_idx, x_live, x_tree) := @compute_live_counts' tx ex idx live in
(x_idx, x_live, Some (tree_nd.LetIn (Some x_count, debug_info) x_tree C_tree))
| None
=> (idx, live, Some (tree_nd.LetIn (None, debug_info) None C_tree))
end
end%bool.
Definition compute_live_counts {t} e : PositiveMap.t _ := snd (@compute_live_counts' t e 1 (PositiveMap.empty _)).
Definition compute_live_counts {t} e : option tree := snd (@compute_live_counts' t e 1 (PositiveMap.empty _)).
Definition ComputeLiveCounts {t} (e : expr.Expr t) := compute_live_counts (e _).

Section with_var.
Expand All @@ -79,36 +95,61 @@ Module Compilers.
in extraction *)
Context (doing_subst_debug : forall T1 T2, T1 -> (unit -> T2) -> T1)
{var : type -> Type}
(should_subst : t -> bool)
(live : PositiveMap.t t).
Fixpoint subst0n' {t} (e : @expr (@expr var) t) (cur_idx : positive)
: positive * @expr var t
(should_subst : t -> bool).
(** When [live] is [None], we don't inline anything, just
dropping [var]. This is required for preventing blowup
in inlining lets in unused [LetIn]-bound expressions.
*)
Fixpoint subst0n (live : option tree) {t} (e : @expr (@expr var) t)
: @expr var t
:= match e with
| expr.Var t v => (cur_idx, v)
| expr.Ident t idc => (cur_idx, expr.Ident idc)
| expr.Var t v => v
| expr.Ident t idc => expr.Ident idc
| expr.App s d f x
=> let '(idx, f') := @subst0n' _ f cur_idx in
let '(idx, x') := @subst0n' _ x idx in
(idx, expr.App f' x')
=> let '(f_live, x_live)
:= match live with
| Some (tree_nd.App _ f_live x_live) => (f_live, x_live)
| _ => (None, None)
end%core in
let f' := @subst0n f_live _ f in
let x' := @subst0n x_live _ x in
expr.App f' x'
| expr.Abs s d f
=> (cur_idx, expr.Abs (fun v => snd (@subst0n' _ (f (expr.Var v)) (Pos.succ cur_idx))))
=> let f_tree
:= match live with
| Some (tree_nd.Abs _ f_tree) => f_tree
| _ => None
end in
expr.Abs (fun v => @subst0n f_tree _ (f (expr.Var v)))
| expr.LetIn tx tC ex eC
=> let '(idx, ex') := @subst0n' tx ex cur_idx in
let eC' := fun v => snd (@subst0n' tC (eC v) (Pos.succ cur_idx)) in
if match PositiveMap.find cur_idx live with
| Some n => should_subst n
| None => true
end
then (Pos.succ cur_idx, eC' (doing_subst_debug _ _ ex' (fun 'tt => (Pos.succ cur_idx, PositiveMap.elements live))))
else (Pos.succ cur_idx, expr.LetIn ex' (fun v => eC' (expr.Var v)))
=> match live with
| Some (tree_nd.LetIn (x_count, debug_info) x_tree C_tree)
=> let ex' := @subst0n x_tree tx ex in
let eC' := fun v => @subst0n C_tree tC (eC v) in
if match x_count with
| Some n => should_subst n
| None => true
end
then eC' (doing_subst_debug _ _ ex' debug_info)
else expr.LetIn ex' (fun v => eC' (expr.Var v))
| _
=> let ex' := @subst0n None tx ex in
let eC' := fun v => @subst0n None tC (eC v) in
expr.LetIn ex' (fun v => eC' (expr.Var v))
end
end.

Definition subst0n {t} e : expr t
:= snd (@subst0n' t e 1).
End with_var.

Definition Subst0n (doing_subst_debug : forall T1 T2, T1 -> (unit -> T2) -> T1) (should_subst : t -> bool) {t} (e : expr.Expr t) : expr.Expr t
:= fun var => subst0n doing_subst_debug should_subst (ComputeLiveCounts e) (e _).
Section with_transport.
Context {try_make_transport_base_type_cps : @type.try_make_transport_cpsT base_type}
{exprDefault : forall var, @DefaultValue.type.base.DefaultT type (@expr var)}.
(** We pass through [Flat] to ensure that the passed in
[Expr] only gets invoked at a single [var] type *)
Definition Subst0n (doing_subst_debug : forall T1 T2, T1 -> (unit -> T2) -> T1) (should_subst : t -> bool) {t} (E : expr.Expr t) : expr.Expr t
:= dlet_nd e := GeneralizeVar.ToFlat E in
let E := GeneralizeVar.FromFlat e in
fun var => subst0n doing_subst_debug should_subst (ComputeLiveCounts E) (E _).
End with_transport.
End with_ident.
End with_counter.

Expand All @@ -122,34 +163,30 @@ Module Compilers.
| more => false
end.

Definition Subst01 {base_type ident} (is_ident_always_live : forall t, ident t -> bool) {t} (e : expr.Expr t) : expr.Expr t
:= @Subst0n _ one incr (fun _ => more) base_type ident is_ident_always_live (fun _ _ x _ => x) should_subst t e.
Definition Subst01
{base_type ident}
{try_make_transport_base_type_cps : @type.try_make_transport_cpsT base_type}
{exprDefault : forall var, @DefaultValue.type.base.DefaultT _ _}
(is_ident_always_live : forall t, ident t -> bool)
{t} (e : expr.Expr t) : expr.Expr t
:= @Subst0n _ one incr (fun _ => more) base_type ident is_ident_always_live try_make_transport_base_type_cps exprDefault (fun _ _ x _ => x) should_subst t e.
End for_01.
End Subst01.

Module DeadCodeElimination.
Section with_ident.
Context {base_type : Type}.
Local Notation type := (type.type base_type).
Context {ident : type -> Type}
(is_ident_always_live : forall t, ident t -> bool).
Local Notation expr := (@expr.expr base_type ident).

Definition OUGHT_TO_BE_UNUSED {T1 T2} (v : T1) (v' : T2) := v.
Global Opaque OUGHT_TO_BE_UNUSED.

Definition ComputeLive {t} (e : expr.Expr t) : PositiveMap.t unit
:= @Subst01.ComputeLiveCounts unit tt (fun _ => tt) (fun _ => tt) base_type ident is_ident_always_live _ e.
Definition is_live (map : PositiveMap.t unit) (idx : positive) : bool
:= match PositiveMap.find idx map with
| Some tt => true
| None => false
end.
Definition is_dead (map : PositiveMap.t unit) (idx : positive) : bool
:= negb (is_live map idx).

Definition EliminateDead {t} (e : expr.Expr t) : expr.Expr t
:= @Subst01.Subst0n unit tt (fun _ => tt) (fun _ => tt) base_type ident is_ident_always_live (fun T1 T2 => OUGHT_TO_BE_UNUSED) (fun _ => false) t e.
Definition EliminateDead
{ident : type -> Type}
{try_make_transport_base_type_cps : @type.try_make_transport_cpsT base_type}
{exprDefault : forall var, @DefaultValue.type.base.DefaultT _ _}
(is_ident_always_live : forall t, ident t -> bool)
{t} (e : expr.Expr t)
: expr.Expr t
:= @Subst01.Subst0n unit tt (fun _ => tt) (fun _ => tt) base_type ident is_ident_always_live try_make_transport_base_type_cps exprDefault (fun T1 T2 => OUGHT_TO_BE_UNUSED) (fun _ => false) t e.
End with_ident.
End DeadCodeElimination.
End Compilers.
Loading

0 comments on commit 7df5f3c

Please sign in to comment.