Skip to content

Commit

Permalink
feat: use scoped trace nodes in linarith (#19855)
Browse files Browse the repository at this point in the history
Inspired by hacking done with @robertylewis and @hrmacbeth which resulted in #19771.

The effect is that the traces messages are now hierarchical; though it's easy not to notice in VSCode without a better version of leanprover/lean4#6345.

See https://profiler.firefox.com/public/smkc5ffh9318w177gps2x9e5b6wy117s6f18e6g/flame-graph/?globalTrackOrder=0&thread=0&transforms=ff-2659&v=10 for an example output produced with
```bash
lake lean MathlibTest/linarith.lean -- \
  -Dtrace.profiler=true \
  -Dtrace.profiler.threshold=1 \
  -Dtrace.profiler.output.pp=true \
  -Dtrace.profiler.output=linarith-profile.json
```

Some inconclusive discussion about best practices for `withTraceNode` is [on Zulip here](https://leanprover.zulipchat.com/#narrow/channel/270676-lean4/topic/Using.20withTraceNode/near/489198580).



Co-authored-by: Eric Wieser <[email protected]>
  • Loading branch information
eric-wieser and eric-wieser committed Jan 11, 2025
1 parent c0cb03a commit 3a798ee
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 74 deletions.
46 changes: 26 additions & 20 deletions Mathlib/Tactic/Linarith/Datatypes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,17 @@ initialize registerTraceClass `linarith.detail

namespace Linarith

/-- A shorthand for getting the types of a list of proofs terms, to trace. -/
def linarithGetProofsMessage (l : List Expr) : MetaM MessageData := do
return m!"{← l.mapM fun e => do instantiateMVars (← inferType e)}"

/--
A shorthand for tracing the types of a list of proof terms
when the `trace.linarith` option is set to true.
-/
def linarithTraceProofs {α} [ToMessageData α] (s : α) (l : List Expr) : MetaM Unit := do
trace[linarith] "{s}"
trace[linarith] (← l.mapM fun e => do instantiateMVars (← inferType e))
if ← isTracingEnabledFor `linarith then
addRawTrace <| .trace { cls := `linarith } (toMessageData s) #[← linarithGetProofsMessage l]

/-! ### Linear expressions -/

Expand Down Expand Up @@ -167,15 +171,20 @@ instance Comp.ToFormat : ToFormat Comp :=

/-! ### Control -/

/-- Metadata about preprocessors, for trace output. -/
structure PreprocessorBase : Type where
/-- The name of the preprocessor, populated automatically, to create linkable trace messages. -/
name : Name := by exact decl_name%
/-- The description of the preprocessor. -/
description : String

/--
A preprocessor transforms a proof of a proposition into a proof of a different proposition.
The return type is `List Expr`, since some preprocessing steps may create multiple new hypotheses,
and some may remove a hypothesis from the list.
A "no-op" preprocessor should return its input as a singleton list.
-/
structure Preprocessor : Type where
/-- The name of the preprocessor, used in trace output. -/
name : String
structure Preprocessor extends PreprocessorBase : Type where
/-- Replace a hypothesis by a list of hypotheses. These expressions are the proof terms. -/
transform : Expr → MetaM (List Expr)

Expand All @@ -184,9 +193,7 @@ Some preprocessors need to examine the full list of hypotheses instead of workin
As with `Preprocessor`, the input to a `GlobalPreprocessor` is replaced by, not added to, its
output.
-/
structure GlobalPreprocessor : Type where
/-- The name of the global preprocessor, used in trace output. -/
name : String
structure GlobalPreprocessor extends PreprocessorBase : Type where
/-- Replace the collection of all hypotheses with new hypotheses.
These expressions are proof terms. -/
transform : List Expr → MetaM (List Expr)
Expand All @@ -206,9 +213,7 @@ Each branch is independent, so hypotheses that appear in multiple branches shoul
The preprocessor is responsible for making sure that each branch contains the correct goal
metavariable.
-/
structure GlobalBranchingPreprocessor : Type where
/-- The name of the global branching preprocessor, used in trace output. -/
name : String
structure GlobalBranchingPreprocessor extends PreprocessorBase : Type where
/-- Given a goal, and a list of hypotheses,
produce a list of pairs (consisting of a goal and list of hypotheses). -/
transform : MVarId → List Expr → MetaM (List Branch)
Expand All @@ -217,14 +222,14 @@ structure GlobalBranchingPreprocessor : Type where
A `Preprocessor` lifts to a `GlobalPreprocessor` by folding it over the input list.
-/
def Preprocessor.globalize (pp : Preprocessor) : GlobalPreprocessor where
name := pp.name
__ := pp
transform := List.foldrM (fun e ret => do return (← pp.transform e) ++ ret) []

/--
A `GlobalPreprocessor` lifts to a `GlobalBranchingPreprocessor` by producing only one branch.
-/
def GlobalPreprocessor.branching (pp : GlobalPreprocessor) : GlobalBranchingPreprocessor where
name := pp.name
__ := pp
transform := fun g l => do return [⟨g, ← pp.transform l⟩]

/--
Expand All @@ -233,13 +238,14 @@ tracing the result if `trace.linarith` is on.
-/
def GlobalBranchingPreprocessor.process (pp : GlobalBranchingPreprocessor)
(g : MVarId) (l : List Expr) : MetaM (List Branch) := g.withContext do
let branches ← pp.transform g l
if branches.length > 1 then
trace[linarith] "Preprocessing: {pp.name} has branched, with branches:"
for ⟨goal, hyps⟩ in branches do
goal.withContext do
linarithTraceProofs m!"Preprocessing: {pp.name}" hyps
return branches
withTraceNode `linarith (fun e =>
return m!"{exceptEmoji e} {.ofConstName pp.name}: {pp.description}") do
let branches ← pp.transform g l
if branches.length > 1 then
trace[linarith] "Preprocessing: {pp.name} has branched, with branches:"
for ⟨goal, hyps⟩ in branches do
trace[linarith] (← goal.withContext <| linarithGetProofsMessage hyps)
return branches

instance PreprocessorToGlobalBranchingPreprocessor :
Coe Preprocessor GlobalBranchingPreprocessor :=
Expand Down
26 changes: 17 additions & 9 deletions Mathlib/Tactic/Linarith/Frontend.lean
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,14 @@ Given a list `ls` of lists of proofs of comparisons, `findLinarithContradiction
prove `False` by calling `linarith` on each list in succession. It will stop at the first proof of
`False`, and fail if no contradiction is found with any list.
-/
def findLinarithContradiction (cfg : LinarithConfig) (g : MVarId) (ls : List (List Expr)) :
def findLinarithContradiction (cfg : LinarithConfig) (g : MVarId) (ls : List (Expr × List Expr)) :
MetaM Expr :=
try
ls.firstM (fun L => proveFalseByLinarith cfg.transparency cfg.oracle cfg.discharger g L)
ls.firstM (fun ⟨α, L⟩ =>
withTraceNode `linarith (return m!"{exceptEmoji ·} running on type {α}") <|
proveFalseByLinarith cfg.transparency cfg.oracle cfg.discharger g L)
catch e => throwError "linarith failed to find a contradiction\n{g}\n{e.toMessageData}"


/--
Given a list `hyps` of proofs of comparisons, `runLinarith cfg hyps prefType`
preprocesses `hyps` according to the list of preprocessors in `cfg`.
Expand All @@ -272,13 +273,20 @@ def runLinarith (cfg : LinarithConfig) (prefType : Option Expr) (g : MVarId)
(hyps : List Expr) : MetaM Unit := do
let singleProcess (g : MVarId) (hyps : List Expr) : MetaM Expr := g.withContext do
linarithTraceProofs s!"after preprocessing, linarith has {hyps.length} facts:" hyps
let hyp_set ← partitionByType hyps
let mut hyp_set ← partitionByType hyps
trace[linarith] "hypotheses appear in {hyp_set.size} different types"
-- If we have a preferred type, strip it from `hyp_set` and prepare a handler with a custom
-- trace message
let pref : MetaM _ ← do
if let some t := prefType then
let (i, vs) ← hyp_set.find t
proveFalseByLinarith cfg.transparency cfg.oracle cfg.discharger g vs <|>
findLinarithContradiction cfg g ((hyp_set.eraseIdxIfInBounds i).toList.map (·.2))
else findLinarithContradiction cfg g (hyp_set.toList.map (·.2))
hyp_set := hyp_set.eraseIdxIfInBounds i
pure <|
withTraceNode `linarith (return m!"{exceptEmoji ·} running on preferred type {t}") <|
proveFalseByLinarith cfg.transparency cfg.oracle cfg.discharger g vs
else
pure failure
pref <|> findLinarithContradiction cfg g hyp_set.toList
let mut preprocessors := cfg.preprocessors
if cfg.splitNe then
preprocessors := Linarith.removeNe :: preprocessors
Expand Down Expand Up @@ -318,8 +326,8 @@ partial def linarith (only_on : Bool) (hyps : List Expr) (cfg : LinarithConfig :
if (← whnfR (← instantiateMVars (← g.getType))).isEq then
trace[linarith] "target is an equality: splitting"
if let some [g₁, g₂] ← try? (g.apply (← mkConst' ``eq_of_not_lt_of_not_gt)) then
linarith only_on hyps cfg g₁
linarith only_on hyps cfg g₂
withTraceNode `linarith (return m!"{exceptEmoji ·} proving ≥") <| linarith only_on hyps cfg g₁
withTraceNode `linarith (return m!"{exceptEmoji ·} proving ≤") <| linarith only_on hyps cfg g₂
return

/- If we are proving a comparison goal (and not just `False`), we consider the type of the
Expand Down
34 changes: 19 additions & 15 deletions Mathlib/Tactic/Linarith/Preprocessing.lean
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ open Batteries (RBSet)

/-- Processor that recursively replaces `P ∧ Q` hypotheses with the pair `P` and `Q`. -/
partial def splitConjunctions : Preprocessor where
name := "split conjunctions"
description := "split conjunctions"
transform := aux
where
/-- Implementation of the `splitConjunctions` preprocessor. -/
Expand All @@ -54,7 +54,7 @@ where
Removes any expressions that are not proofs of inequalities, equalities, or negations thereof.
-/
partial def filterComparisons : Preprocessor where
name := "filter terms that are not proofs of comparisons"
description := "filter terms that are not proofs of comparisons"
transform h := do
let tp ← instantiateMVars (← inferType h)
try
Expand All @@ -80,7 +80,7 @@ Replaces proofs of negations of comparisons with proofs of the reversed comparis
For example, a proof of `¬ a < b` will become a proof of `a ≥ b`.
-/
def removeNegations : Preprocessor where
name := "replace negations of comparisons"
description := "replace negations of comparisons"
transform h := do
let t : Q(Prop) ← whnfR (← inferType h)
match t with
Expand Down Expand Up @@ -147,7 +147,7 @@ It also adds the facts that the integers involved are nonnegative.
To avoid adding the same nonnegativity facts many times, it is a global preprocessor.
-/
def natToInt : GlobalBranchingPreprocessor where
name := "move nats to ints"
description := "move nats to ints"
transform g l := do
let l ← l.mapM fun h => do
let t ← whnfR (← instantiateMVars (← inferType h))
Expand Down Expand Up @@ -192,7 +192,7 @@ def mkNonstrictIntProof (pf : Expr) : MetaM (Option Expr) := do
/-- `strengthenStrictInt h` turns a proof `h` of a strict integer inequality `t1 < t2`
into a proof of `t1 ≤ t2 + 1`. -/
def strengthenStrictInt : Preprocessor where
name := "strengthen strict inequalities over int"
description := "strengthen strict inequalities over int"
transform h := return [(← mkNonstrictIntProof h).getD h]

end strengthenStrictInt
Expand All @@ -214,7 +214,7 @@ partial def rearrangeComparison (e : Expr) : MetaM (Option Expr) := do
and turns it into a proof of a comparison `_ R 0`, where `R ∈ {=, ≤, <}`.
-/
def compWithZero : Preprocessor where
name := "make comparisons with zero"
description := "make comparisons with zero"
transform e := return (← rearrangeComparison e).toList

end compWithZero
Expand Down Expand Up @@ -247,7 +247,7 @@ def normalizeDenominatorsLHS (h lhs : Expr) : MetaM Expr := do
it tries to scale `t` to cancel out division by numerals.
-/
def cancelDenoms : Preprocessor where
name := "cancel denominators"
description := "cancel denominators"
transform := fun pf => (do
let (_, lhs) ← parseCompAndExpr (← inferType pf)
guard <| lhs.containsConst <| fun n =>
Expand Down Expand Up @@ -288,7 +288,8 @@ partial def findSquares (s : RBSet (Nat × Bool) lexOrd.compare) (e : Expr) :
| _ => e.foldlM findSquares s

/-- Get proofs of `-x^2 ≤ 0` and `-(x*x) ≤ 0`, when those terms appear in `ls` -/
private def nlinarithGetSquareProofs (ls : List Expr) : MetaM (List Expr) := do
private def nlinarithGetSquareProofs (ls : List Expr) : MetaM (List Expr) :=
withTraceNode `linarith (return m!"{exceptEmoji ·} finding squares") do
-- find the squares in `AtomM` to ensure deterministic behavior
let s ← AtomM.run .reducible do
let si ← ls.foldrM (fun h s' => do findSquares s' (← instantiateMVars (← inferType h)))
Expand All @@ -297,8 +298,7 @@ private def nlinarithGetSquareProofs (ls : List Expr) : MetaM (List Expr) := do
let new_es ← s.filterMapM fun (e, is_sq) =>
observing? <| mkAppM (if is_sq then ``sq_nonneg else ``mul_self_nonneg) #[e]
let new_es ← compWithZero.globalize.transform new_es
trace[linarith] "nlinarith preprocessing found squares"
trace[linarith] "{s}"
trace[linarith] "found:{indentD <| toMessageData s}"
linarithTraceProofs "so we added proofs" new_es
return new_es

Expand All @@ -308,7 +308,8 @@ Get proofs for products of inequalities from `ls`.
Note that the length of the resulting list is proportional to `ls.length^2`, which can make a large
amount of work for the linarith oracle.
-/
private def nlinarithGetProductsProofs (ls : List Expr) : MetaM (List Expr) := do
private def nlinarithGetProductsProofs (ls : List Expr) : MetaM (List Expr) :=
withTraceNode `linarith (return m!"{exceptEmoji ·} adding product terms") do
let with_comps ← ls.mapM (fun e => do
let tp ← inferType e
try
Expand Down Expand Up @@ -341,7 +342,7 @@ private def nlinarithGetProductsProofs (ls : List Expr) : MetaM (List Expr) := d
This preprocessor is typically run last, after all inputs have been canonized.
-/
def nlinarithExtras : GlobalPreprocessor where
name := "nonlinear arithmetic extras"
description := "nonlinear arithmetic extras"
transform ls := do
let new_es ← nlinarithGetSquareProofs ls
let products ← nlinarithGetProductsProofs (new_es ++ ls)
Expand Down Expand Up @@ -374,7 +375,7 @@ by calling `linarith.removeNe_aux`.
This produces `2^n` branches when there are `n` such hypotheses in the input.
-/
def removeNe : GlobalBranchingPreprocessor where
name := "removeNe"
description := "case split on ≠"
transform := removeNe_aux
end removeNe

Expand All @@ -394,7 +395,10 @@ Note that a preprocessor may produce multiple or no expressions from each input
so the size of the list may change.
-/
def preprocess (pps : List GlobalBranchingPreprocessor) (g : MVarId) (l : List Expr) :
MetaM (List Branch) := g.withContext <|
pps.foldlM (fun ls pp => return (← ls.mapM fun (g, l) => do pp.process g l).flatten) [(g, l)]
MetaM (List Branch) := do
withTraceNode `linarith (fun e => return m!"{exceptEmoji e} Running preprocessors") <|
g.withContext <|
pps.foldlM (init := [(g, l)]) fun ls pp => do
return (← ls.mapM fun (g, l) => do pp.process g l).flatten

end Linarith
71 changes: 41 additions & 30 deletions Mathlib/Tactic/Linarith/Verification.lean
Original file line number Diff line number Diff line change
Expand Up @@ -191,41 +191,52 @@ def proveFalseByLinarith (transparency : TransparencyMode) (oracle : Certificate
(discharger : TacticM Unit) : MVarId → List Expr → MetaM Expr
| _, [] => throwError "no args to linarith"
| g, l@(h::_) => do
trace[linarith.detail] "Beginning work in `proveFalseByLinarith`."
Lean.Core.checkSystem decl_name%.toString
-- for the elimination to work properly, we must add a proof of `-1 < 0` to the list,
-- along with negated equality proofs.
let l' ← addNegEqProofs l
trace[linarith.detail] "... finished `addNegEqProofs`."
let inputs := (← mkNegOneLtZeroProof (← typeOfIneqProof h))::l'.reverse
trace[linarith.detail] "... finished `mkNegOneLtZeroProof`."
trace[linarith.detail] (← inputs.mapM inferType)
let (comps, max_var) ← linearFormsAndMaxVar transparency inputs
trace[linarith.detail] "... finished `linearFormsAndMaxVar`."
trace[linarith.detail] "{comps}"
let l' ← detailTrace "addNegEqProofs" <| addNegEqProofs l
let inputs ← detailTrace "mkNegOneLtZeroProof" <|
return (← mkNegOneLtZeroProof (← typeOfIneqProof h))::l'.reverse
trace[linarith.detail] "inputs:{indentD <| toMessageData (← inputs.mapM inferType)}"
let (comps, max_var) ← detailTrace "linearFormsAndMaxVar" <|
linearFormsAndMaxVar transparency inputs
trace[linarith.detail] "comps:{indentD <| toMessageData comps}"
-- perform the elimination and fail if no contradiction is found.
let certificate : Std.HashMap Nat Nat ← try
oracle.produceCertificate comps max_var
catch e =>
trace[linarith] e.toMessageData
throwError "linarith failed to find a contradiction"
trace[linarith] "linarith has found a contradiction: {certificate.toList}"
let enum_inputs := inputs.enum
-- construct a list pairing nonzero coeffs with the proof of their corresponding comparison
let zip := enum_inputs.filterMap fun ⟨n, e⟩ => (certificate[n]?).map (e, ·)
let mls ← zip.mapM fun ⟨e, n⟩ => do mulExpr n (← leftOfIneqProof e)
-- `sm` is the sum of input terms, scaled to cancel out all variables.
let sm ← addExprs mls
-- let sm ← instantiateMVars sm
trace[linarith] "The expression\n {sm}\nshould be both 0 and negative"
let certificate : Std.HashMap Nat Nat ←
withTraceNode `linarith (return m!"{exceptEmoji ·} Invoking oracle") do
let certificate ←
try
oracle.produceCertificate comps max_var
catch e =>
trace[linarith] e.toMessageData
throwError "linarith failed to find a contradiction"
trace[linarith] "found a contradiction: {certificate.toList}"
return certificate
let (sm, zip) ←
withTraceNode `linarith (return m!"{exceptEmoji ·} Building final expression") do
let enum_inputs := inputs.enum
-- construct a list pairing nonzero coeffs with the proof of their corresponding
-- comparison
let zip := enum_inputs.filterMap fun ⟨n, e⟩ => (certificate[n]?).map (e, ·)
let mls ← zip.mapM fun ⟨e, n⟩ => do mulExpr n (← leftOfIneqProof e)
-- `sm` is the sum of input terms, scaled to cancel out all variables.
let sm ← addExprs mls
-- let sm ← instantiateMVars sm
trace[linarith] "{indentD sm}\nshould be both 0 and negative"
return (sm, zip)
-- we prove that `sm = 0`, typically with `ring`.
let sm_eq_zero ← proveEqZeroUsing discharger sm
let sm_eq_zero ← detailTrace "proveEqZeroUsing" <| proveEqZeroUsing discharger sm
-- we also prove that `sm < 0`
let sm_lt_zero ← mkLTZeroProof zip
-- this is a contradiction.
let pftp ← inferType sm_lt_zero
let ⟨_, nep, _⟩ ← g.rewrite pftp sm_eq_zero
let pf' ← mkAppM ``Eq.mp #[nep, sm_lt_zero]
mkAppM ``Linarith.lt_irrefl #[pf']
let sm_lt_zero ← detailTrace "mkLTZeroProof" <| mkLTZeroProof zip
detailTrace "Linarith.lt_irrefl" do
-- this is a contradiction.
let pftp ← inferType sm_lt_zero
let ⟨_, nep, _⟩ ← g.rewrite pftp sm_eq_zero
let pf' ← mkAppM ``Eq.mp #[nep, sm_lt_zero]
mkAppM ``Linarith.lt_irrefl #[pf']
where
/-- Log `f` under `linarith.detail`, with exception emojis and the provided name. -/
detailTrace {α} (s : String) (f : MetaM α) : MetaM α :=
withTraceNode `linarith.detail (return m!"{exceptEmoji ·} {s}") f

end Linarith

0 comments on commit 3a798ee

Please sign in to comment.