diff --git a/SciLean/Tactic/DataSynth/Main.lean b/SciLean/Tactic/DataSynth/Main.lean index ec3fd023..28aabcf9 100644 --- a/SciLean/Tactic/DataSynth/Main.lean +++ b/SciLean/Tactic/DataSynth/Main.lean @@ -12,6 +12,10 @@ open Lean Meta private def withProfileTrace (msg : String) (x : DataSynthM α) : DataSynthM α := withTraceNode `Meta.Tactic.data_synth.profile (fun _ => return msg) x +private def withMainTrace (msg : Except Exception α → DataSynthM MessageData) (x : DataSynthM α) : + DataSynthM α := + withTraceNode `Meta.Tactic.data_synth msg x + def Simp.lsimp (e : Expr) : SimpM Simp.Result := let r := do @@ -183,8 +187,7 @@ partial def normalizeCore (e : Expr) : DataSynthM Expr := def normalize (e : Expr) : DataSynthM (Simp.Result) := do - withTraceNode - `Meta.Tactic.data_synth + withMainTrace (fun _ => return m!"normalization") do let cfg := (← read).config @@ -316,10 +319,8 @@ def synthesizeArgument (x : Expr) : DataSynthM Bool := do /- -/ def tryTheorem? (e : Expr) (thm : DataSynthTheorem) : DataSynthM (Option Expr) := do - withProfileTrace "tryTheorem" do - withTraceNode - `Meta.Tactic.data_synth + withMainTrace (fun r => return m!"[{ExceptToEmoji.toEmoji r}] applying {← ppOrigin (.decl thm.thmName)}") do let thmProof ← thm.getProof @@ -336,9 +337,6 @@ def tryTheorem? (e : Expr) (thm : DataSynthTheorem) : DataSynthM (Option Expr) : for x in xs do let _ ← synthesizeArgument x - -- for x in xs do - -- let _ ← synthesizeArgument x - -- check if all arguments have been synthesized for x in xs do let x ← instantiateMVars x @@ -349,6 +347,22 @@ def tryTheorem? (e : Expr) (thm : DataSynthTheorem) : DataSynthM (Option Expr) : return some thmProof +def Goal.tryTheorem? (goal : Goal) (thm : DataSynthTheorem) (normalize := true) : DataSynthM (Option Result) := do + withProfileTrace "tryTheorem" do + + let (xs, e) ← goal.mkFreshProofGoal + + let .some prf ← DataSynth.tryTheorem? e thm | return none + + let mut r := Result.mk xs prf goal + + if normalize then + r ← r.normalize + + return r + + + -- main function that looks up theorems partial def main (goal : Goal) : DataSynthM (Option Result) := do withProfileTrace "main" do @@ -362,18 +376,7 @@ partial def main (goal : Goal) : DataSynthM (Option Result) := do trace[Meta.Tactic.data_synth] "candidates {thms.map (fun t => t.thmName)}" for thm in thms do - -- for each theorem we generate a fresh data mvars `xs` because them might get partially filled - -- when unsuccesfully trying a theorem - let (xs, e) ← goal.mkFreshProofGoal - if let .some prf ← tryTheorem? e thm then - -- result - let r := Result.mk xs prf goal - - -- normalize synthsized data - let rs ← xs.mapM (fun x => instantiateMVars x >>= normalize) - - -- fix proof - let r ← r.congr rs + if let .some r ← goal.tryTheorem? thm then return r return none @@ -396,7 +399,7 @@ def mainCached (goal : Goal) (initialTrace := true) : DataSynthM (Option Result) return none if initialTrace then - withTraceNode `Meta.Tactic.data_synth + withMainTrace (fun r => match r with | .ok (some _r) => return m!"[✅] {← goal.pp}" @@ -442,7 +445,7 @@ def letGoals (fgGoal : Goal) (f g : Expr) : DataSynthM (Option (Goal×Goal)) := let (xs, _, thm) ← forallMetaTelescope info.type try - withTraceNode `Meta.Tactic.data_synth (fun _ => return m!"assigning data") do + withMainTrace (fun _ => return m!"assigning data") do xs[gId]!.mvarId!.assignIfDefeq g xs[fId]!.mvarId!.assignIfDefeq f catch _ => @@ -471,13 +474,9 @@ def letResults (fgGoal : Goal) (f g : Expr) (hf hg : Result) : DataSynthM (Optio args? := args?.set! hgId hg.proof args? := args?.set! hfId hf.proof - let proof ← - withTraceNode `Meta.Tactic.data_synth (fun _ => return m!"theorem application") do - mkAppOptM thmName args? + let proof ← mkAppOptM thmName args? - let r ← - withTraceNode `Meta.Tactic.data_synth (fun _ => return m!"getting result from proof") do - fgGoal.getResultFrom proof + let r ← fgGoal.getResultFrom proof return r @@ -553,6 +552,32 @@ def projResults (fGoal : Goal) (f g p₁ p₂ q : Expr) (hg : Result) : DataSynt return r +def constCase? (goal : Goal) (f : FunData) : DataSynthM (Option Result) := do + + let vars := (← f.body.collectFVars |>.run {}).2.fvarSet + let (xs₁, xs₂) := f.xs.split (fun x => vars.contains x.fvarId!) + + unless xs₁.size = 0 do return none + withProfileTrace "const case" do + withMainTrace (fun _ => return "constant function") do + + let (xs, e) ← goal.mkFreshProofGoal + + let thm : DataSynthTheorem ← + getTheoremFromConst (goal.dataSynthDecl.name.append `const_rule) + + let .some prf ← tryTheorem? e thm | return none + + -- result + let r := Result.mk xs prf goal + + -- fix proof + let r ← r.normalize + return r + + + + def decomposeDomain? (goal : Goal) (f : FunData) : DataSynthM (Option Result) := do if ¬(← read).config.domainDec then return none @@ -601,6 +626,10 @@ def lamCase (goal : Goal) (f : FunData) : DataSynthM (Option Result) := do partial def mainFun (goal : Goal) (f : FunData) : DataSynthM (Option Result) := do withProfileTrace "mainFun" do + -- spacial case for constant functions + if let some r ← constCase? goal f then + return r + -- decompose domain if possible if let some r ← decomposeDomain? goal f then return r @@ -623,7 +652,7 @@ partial def mainFun (goal : Goal) (f : FunData) : DataSynthM (Option Result) := def mainFunCached (goal : Goal) (f : FunData) : DataSynthM (Option Result) := do - withTraceNode `Meta.Tactic.data_synth + withMainTrace (fun r => match r with | .ok (some r) => return m!"[✅] {← goal.pp}" diff --git a/SciLean/Tactic/DataSynth/Types.lean b/SciLean/Tactic/DataSynth/Types.lean index e35c357c..47278c50 100644 --- a/SciLean/Tactic/DataSynth/Types.lean +++ b/SciLean/Tactic/DataSynth/Types.lean @@ -64,7 +64,6 @@ def congr (r : Result) (rs : Array Simp.Result) : MetaM Result := do -- proof that original result is equal to the result with normalized data let hgoal ← - withTraceNode `Meta.Tactic.data_synth (fun _ => return m!"goal congr fold") do (r.xs.zip rs).foldlM (init:= ← mkEqRefl goal) (fun g (x,r) => match r.proof? with