Skip to content

Commit

Permalink
chore: stash
Browse files Browse the repository at this point in the history
  • Loading branch information
bollu committed Oct 30, 2024
1 parent 25fffe0 commit efde767
Show file tree
Hide file tree
Showing 2 changed files with 254 additions and 171 deletions.
188 changes: 129 additions & 59 deletions Tactics/CSE.lean
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ structure CSEConfig where
/-- Whether the tactic should throw an error if no CSEable subterms were found. -/
failIfUnchanged : Bool := true
/-- Number of steps the tactic should spend searching for subterms to gather information. -/
fuelSearch : Nat := 1000
fuelSearch : Nat := 99999
/-- Number of steps the tactic should spend performing subexpression elimination.
It can be useful to have large amounts of fuel for searching, and very little for eliminating,
to search for maximal subterm sharing, and to then eliminate the most common occurrences.
Expand All @@ -127,13 +127,18 @@ structure ExprData where
size : Nat
deriving Repr

/-- Render an `ExprData` as a short message showing its `occs` and `size` fields. -/
instance : ToMessageData ExprData :=
  ⟨fun d => m!"occs:{d.occs} size:{d.size}"⟩


/-- Return a copy of `data` whose `occs` field is one larger; every other field is kept as is. -/
def ExprData.incrRef (data : ExprData) : ExprData :=
  let bumped := data.occs + 1
  { data with occs := bumped }


structure State where
/--
A mapping from expression to its canonical index.
TODO: replace with discrimination tree to allow sharing in the keys.
-/
canon2data : Std.HashMap Expr ExprData := {}
/--
Expand Down Expand Up @@ -201,51 +206,98 @@ def ExprData.isProfitable? (data : ExprData) : CSEM Bool :=
return data.size > 1 && data.occs >= (← getConfig).minOccsToCSE


/--
Record one more sighting of `e` in the state's `canon2data` table and return the stored entry.

* If `e` is already a key, its entry is returned with `occs` incremented. The stored `size`
  is kept and the `size` argument is not consulted.
* Otherwise a fresh entry `{ occs := 1, size }` is inserted and returned.

The function is not recursive, so it is a plain `def` rather than a `partial def`.
-/
def CSEM.addOrUpdateData (e : Expr) (size : Nat) : CSEM ExprData := do
  let s ← getState
  let prev? := s.canon2data[e]?
  let data : ExprData :=
    match prev? with
    | some old => old.incrRef
    | none => { occs := 1, size : ExprData }
  -- Keep first sightings and repeat sightings distinguishable in the trace output.
  if prev?.isSome then
    traceLargeMsg m!"updated expr (...) with info ({repr data})" m!"{e}"
  else
    traceLargeMsg m!"Added new expr (...) with info ({repr data})" m!"{e}"
  /- Write the new or updated info back under the same key. -/
  setState { s with canon2data := s.canon2data.insert e data }
  return data

mutual

/-
For now, we only visit those expressions that do not create new binders.
forall: we visit non dependent arrows, since they do not add new binders.
A dependent arrow adds a binder, thanks to `(w : Nat) → BitVec w` creating a binder for `w`.
lam: we don't visit them, as their argument creates a binding.
letE: we don't visit them, as they create a let binding.
app, mdata, proj: we visit them, since they don't create binders.
-/


/--
Visit a projection by running `tryAddExpr` on the expression being projected from
(`e.projExpr!`). The projection node `e` itself is not recorded by this function.
-/
partial def CSEM.tryAddExpr.visitProj (g : MVarId) (e : Expr) : CSEM (Option ExprData) :=
  -- `projExpr!` panics unless `e` is a `.proj` node, so callers must dispatch on the constructor first.
  let projected := e.projExpr!
  tryAddExpr g projected

/--
Visit a metadata-wrapped expression by stripping the metadata (`e.consumeMData`) and running
`tryAddExpr` on what remains. The wrapper node `e` itself is not recorded by this function.
-/
partial def CSEM.tryAddExpr.visitMData (g : MVarId) (e : Expr) : CSEM (Option ExprData) :=
  let stripped := e.consumeMData
  tryAddExpr g stripped


/--
Visit a `forall`/arrow expression.

Only non-dependent arrows (`e.arrow?` succeeds) are processed: both sides are visited with
`tryAddExpr`, and if both yield data, `e` is recorded with size `lhs.size + rhs.size + 1`.
Dependent foralls, and arrows where either side yields no data, produce `none`.
-/
partial def CSEM.tryAddExpr.visitForall (g : MVarId) (e : Expr) : CSEM (Option ExprData) := do
  withTraceNode "=>" (collapsed := false) do
    match e.arrow? with
    | none =>
      -- A dependent arrow introduces a binder, so it is skipped.
      -- NOTE(review): this trace class is spelled `tactic.cse.summary`, while `tryAddExpr`
      -- uses `Tactics.cse.summary` — confirm which one is actually registered.
      trace[tactic.cse.summary] "found dependent forall, not recursing into it"
      return none
    | some (dom, cod) =>
      -- Visit both sides before inspecting either result, so the codomain is always explored.
      let domData? ← tryAddExpr g dom
      let codData? ← tryAddExpr g cod
      match domData?, codData? with
      | some domData, some codData =>
        let data ← addOrUpdateData e (domData.size + codData.size + 1)
        return some data
      | _, _ => return none

/--
Visit a function application and build the `ExprData` that corresponds to it.

The head function and every explicit argument are visited with `tryAddExpr`. The size of `e`
is the sum of the sizes reported for those subterms, where a subterm that yields no data
counts as `1`. Implicit arguments are neither visited nor counted. Finally `e` itself is
recorded through `addOrUpdateData`, and the resulting entry is returned.
-/
partial def CSEM.tryAddExpr.visitApp (g : MVarId) (e : Expr) : CSEM (Option ExprData) := do
  withTraceNode "=> ap" (collapsed := false) do
    -- Use the matching `getAppFn`/`getAppArgs` pair. Mixing `getAppFn'` (which looks through
    -- `mdata`) with `getAppArgs` (which does not) can pair a head with the argument list of a
    -- different application spine, misaligning `paramInfos` below.
    let fn := e.getAppFn
    let args := e.getAppArgs
    let fnData? ← tryAddExpr g fn
    let mut size := (fnData?.map (·.size)).getD 1
    let paramInfos := (← getFunInfo fn).paramInfo
    for i in [0 : args.size] do
      let arg := args[i]!
      /- If we have an application, then only add its children that are explicit. -/
      -- An over-applied function has more arguments than `paramInfos` has entries. Treat the
      -- extra arguments as explicit instead of indexing out of bounds with `[i]!`.
      let isExplicit :=
        match paramInfos[i]? with
        | some info => info.isExplicit
        | none => true
      if isExplicit then
        let argData? ← tryAddExpr g arg
        size := size + (argData?.map (·.size)).getD 1
    let data ← addOrUpdateData e size
    return some data

/--
The function is partial because of the call to `tryAddExpr` that
Lean does not infer is smaller in `e`.
-/
partial def CSEM.tryAddExpr (e : Expr) : CSEM (Option ExprData) := do
consumeSearchFuel
unless (← hasSearchFuel) do
trace[Tactics.cse.summary] "⏸️ CSE ran out of fuel while looking for subexpressions. Increase `fuelSearch` in CSEConfig."
return .none

let t ← inferType e
-- for now, we ignore function terms and all literals.
let relevant? := !t.isArrow && !t.isSort && !t.isForall && !(← isLitValue e)
withTraceNode m!"({e}):({t}) [relevant? {if relevant? then checkEmoji else crossEmoji}] (unfold for subexpressions...)" do
/-
If we have an application, then only add its children
that are explicit.
-/
let mut size := 1
if e.isApp then
let (fn, args) := (e.getAppFn, e.getAppArgs)
let paramInfos := (← getFunInfo fn).paramInfo
for i in [0 : args.size] do
let arg := args[i]!
let shouldCount := paramInfos[i]!.isExplicit
if shouldCount then
if let .some data ← tryAddExpr arg then
size := size + data.size
-- the current argument itself was irrelevant, so don't bother adding it.
if !relevant? then return .none
let s ← getState
match s.canon2data[e]? with
| .some data => do
let data := data.incrRef
traceLargeMsg m!"updated expr (...) with info ({repr data})" m!"{e}"
setState { s with canon2data := s.canon2data.insert e data }
return .some data
| .none =>
let data := { occs := 1, size : ExprData }
setState {
s with
canon2data := s.canon2data.insert e data,
}
traceLargeMsg m!"Added new expr (...) with info ({repr data})" m!"{e}"
return .some data
partial def CSEM.tryAddExpr (g : MVarId) (e : Expr) : CSEM (Option ExprData) := do
  -- Run inside the local context of goal `g`.
  g.withContext do
    -- Fuel is consumed first and checked afterwards; with no fuel left the visit is abandoned.
    consumeSearchFuel
    unless (← hasSearchFuel) do
      trace[Tactics.cse.summary] "⏸️ CSE ran out of fuel while looking for subexpressions. Increase `fuelSearch` in CSEConfig."
      return .none

    match e with
    -- Applications are the only expressions currently recorded.
    | .app .. => tryAddExpr.visitApp g e
    -- These have dedicated visitors (`visitForall`, `visitMData`, `visitProj`), but they are
    -- not dispatched to at the moment.
    | .forallE .. | .mdata .. | .proj .. => return none
    -- Binder-introducing terms (`lam`, `letE`) and atomic terms are ignored.
    | .lam .. | .letE .. | .lit .. | .const .. | .sort .. | .mvar .. | .fvar .. | .bvar .. =>
      return .none

end

/-- Execute `x` using the main goal local context. -/
def CSEM.withMainContext (x : CSEM α) : CSEM α := do
Expand Down Expand Up @@ -309,37 +361,42 @@ def CSEM.generalize (arg : GeneralizeArg) : CSEM Bool := do
traceLargeMsg m!"{bombEmoji} failed to generalize {hname}" m!"{e.toMessageData}"
return false

def CSEM.cseImpl : CSEM Unit := do
-- This function is very slow to elaborate, why?
def CSEM.cseImpl (hinting : Bool) : CSEM Unit := do
withMainContext do
withTraceNode m!"CSE collecting hypotheses:" do
let _ ← tryAddExpr (← getMainTarget)
withTraceNode m!"🧺 CSE collecting hypotheses:" do
let _ ← tryAddExpr (← getMainGoal) (← getMainTarget)
/- If we should perform CSE on hypotheses as well. -/
if (← getConfig).processHyps = .allHyps then
for hyp in (← getLocalHyps) do
let _ ← tryAddExpr (← inferType hyp)
let _ ← tryAddExpr (← getMainGoal) (← inferType hyp)

let newCanon2Data : Std.HashMap Expr ExprData ←
-- Map from each profitable expression to its data.
let e2data : Std.HashMap Expr ExprData ←
withTraceNode m!"⏭️ CSE eliminiating unprofitable expressions (#expressions:{(← getState).canon2data.size}):" do
let mut newCanon2Data := {}
let mut e2data := {}
for (e, data) in (← getState).canon2data.toArray.qsort (fun kv kv' => kv.2.occs > kv'.2.occs) do
if !(← data.isProfitable?) then
traceLargeMsg m!"⏭️ Unprofitable {repr data} ." m!"expr: {e}"
else
newCanon2Data := newCanon2Data.insert e data
return newCanon2Data
e2data := e2data.insert e data
return e2data

withTraceNode m!"CSE summary of profitable expressions (#expressions:{(newCanon2Data).size}):" <| do
-- This block seems to be very slow to elaborate. why?
withTraceNode m!"💸 CSE profitable (#expressions:{(e2data).size}):" do
let hintMsg : MessageData := MessageData.nil
let mut i : Nat := 1
/- Sort to show the most numerous first, followed by the smallest. -/
for (e, data) in newCanon2Data.toArray.qsort (fun kv kv' => kv.2.occs > kv'.2.occs) do
if !(← data.isProfitable?) then continue
for (e, data) in e2data.toArray.qsort (fun kv kv' => kv.2.occs > kv'.2.occs) do
traceLargeMsg m!"{i}) {repr data}" m!"{e}"
i := i + 1
/- We're providing user hints, so we print the information about profitable expressions directly as a user message -/
if hinting then
logInfo m!"{i}) {toMessageData data} {e}"


withTraceNode m!"CSE rewriting (#expressions:{newCanon2Data.size}):" do
withTraceNode m!"▶️ CSE rewriting (#expressions:{e2data.size}):" do
let mut madeProgress := false
for (e, _data) in newCanon2Data.toArray.qsort (fun kv kv' => kv.2.size > kv'.2.size) do
for (e, _data) in e2data.toArray.qsort (fun kv kv' => kv.2.size > kv'.2.size || (kv.2.size == kv'.2.size && kv.1.hash < kv'.1.hash)) do
let generalizeArg ← planCSE e
madeProgress := madeProgress || (← generalize generalizeArg)
if !madeProgress && !(← getConfig).failIfUnchanged
Expand All @@ -349,10 +406,10 @@ def CSEM.cseImpl : CSEM Unit := do
open Lean Elab Tactic Parser.Tactic

/-- The `cse` tactic, for performing common subexpression elimination of goal states. -/
def cseTactic (cfg : CSEConfig) : TacticM Unit := CSEM.cseImpl |>.run cfg
def cseTactic (cfg : CSEConfig) : TacticM Unit :=
  -- `hinting := false`: profitable expressions are only traced, not reported via `logInfo`.
  (CSEM.cseImpl (hinting := false)).run cfg

/-- The `cse` tactic with the default configuration. -/
def cseTacticDefault : TacticM Unit := CSEM.cseImpl |>.run {}
/--
The hinting variant of the `cse` tactic: runs `CSEM.cseImpl` with `hinting := true`, so that
each profitable expression is additionally reported to the user with `logInfo`.
-/
def cseHintsTactic (cfg : CSEConfig) : TacticM Unit :=
  (CSEM.cseImpl (hinting := true)).run cfg

end Tactic.CSE

Expand All @@ -372,3 +429,16 @@ def evalCse : Tactic := fun
let cfg ← elabCSEConfig (mkOptionalNode cfg)
Tactic.CSE.cseTactic cfg
| _ => throwUnsupportedSyntax

/--
`cse?`: common subexpression elimination, but with user guidance.
Provides hints about which expressions can be CSE'd. Accepts an optional configuration.
-/
syntax (name := cseHints) "cse?" (Lean.Parser.Tactic.config)? : tactic

/--
Elaborator for `cse?`: elaborates the optional configuration with `elabCSEConfig` and runs
`Tactic.CSE.cseHintsTactic`. Any other syntax is rejected with `throwUnsupportedSyntax`.
-/
@[tactic cseHints]
def evalCseHints : Tactic := fun stx =>
  match stx with
  | `(tactic| cse? $[$cfg]?) => do
    let cseCfg ← elabCSEConfig (mkOptionalNode cfg)
    Tactic.CSE.cseHintsTactic cseCfg
  | _ => throwUnsupportedSyntax
Loading

0 comments on commit efde767

Please sign in to comment.