Skip to content

Commit

Permalink
special constant function case in data_synth and tracing clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Dec 5, 2024
1 parent 54431cf commit df6e8a2
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 30 deletions.
87 changes: 58 additions & 29 deletions SciLean/Tactic/DataSynth/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ open Lean Meta
private def withProfileTrace (msg : String) (x : DataSynthM α) : DataSynthM α :=
withTraceNode `Meta.Tactic.data_synth.profile (fun _ => return msg) x

private def withMainTrace (msg : Except Exception α → DataSynthM MessageData) (x : DataSynthM α) :
DataSynthM α :=
withTraceNode `Meta.Tactic.data_synth msg x


def Simp.lsimp (e : Expr) : SimpM Simp.Result :=
let r := do
Expand Down Expand Up @@ -183,8 +187,7 @@ partial def normalizeCore (e : Expr) : DataSynthM Expr :=

def normalize (e : Expr) : DataSynthM (Simp.Result) := do

withTraceNode
`Meta.Tactic.data_synth
withMainTrace
(fun _ => return m!"normalization") do

let cfg := (← read).config
Expand Down Expand Up @@ -316,10 +319,8 @@ def synthesizeArgument (x : Expr) : DataSynthM Bool := do
/-
-/
def tryTheorem? (e : Expr) (thm : DataSynthTheorem) : DataSynthM (Option Expr) := do
withProfileTrace "tryTheorem" do

withTraceNode
`Meta.Tactic.data_synth
withMainTrace
(fun r => return m!"[{ExceptToEmoji.toEmoji r}] applying {← ppOrigin (.decl thm.thmName)}") do

let thmProof ← thm.getProof
Expand All @@ -336,9 +337,6 @@ def tryTheorem? (e : Expr) (thm : DataSynthTheorem) : DataSynthM (Option Expr) :
for x in xs do
let _ ← synthesizeArgument x

-- for x in xs do
-- let _ ← synthesizeArgument x

-- check if all arguments have been synthesized
for x in xs do
let x ← instantiateMVars x
Expand All @@ -349,6 +347,22 @@ def tryTheorem? (e : Expr) (thm : DataSynthTheorem) : DataSynthM (Option Expr) :
return some thmProof


def Goal.tryTheorem? (goal : Goal) (thm : DataSynthTheorem) (normalize := true) : DataSynthM (Option Result) := do
withProfileTrace "tryTheorem" do

let (xs, e) ← goal.mkFreshProofGoal

let .some prf ← DataSynth.tryTheorem? e thm | return none

let mut r := Result.mk xs prf goal

if normalize then
r ← r.normalize

return r



-- main function that looks up theorems
partial def main (goal : Goal) : DataSynthM (Option Result) := do
withProfileTrace "main" do
Expand All @@ -362,18 +376,7 @@ partial def main (goal : Goal) : DataSynthM (Option Result) := do
trace[Meta.Tactic.data_synth] "candidates {thms.map (fun t => t.thmName)}"

for thm in thms do
-- for each theorem we generate a fresh data mvars `xs` because them might get partially filled
-- when unsuccesfully trying a theorem
let (xs, e) ← goal.mkFreshProofGoal
if let .some prf ← tryTheorem? e thm then
-- result
let r := Result.mk xs prf goal

-- normalize synthsized data
let rs ← xs.mapM (fun x => instantiateMVars x >>= normalize)

-- fix proof
let r ← r.congr rs
if let .some r ← goal.tryTheorem? thm then
return r

return none
Expand All @@ -396,7 +399,7 @@ def mainCached (goal : Goal) (initialTrace := true) : DataSynthM (Option Result)
return none

if initialTrace then
withTraceNode `Meta.Tactic.data_synth
withMainTrace
(fun r =>
match r with
| .ok (some _r) => return m!"[✅] {← goal.pp}"
Expand Down Expand Up @@ -442,7 +445,7 @@ def letGoals (fgGoal : Goal) (f g : Expr) : DataSynthM (Option (Goal×Goal)) :=
let (xs, _, thm) ← forallMetaTelescope info.type

try
withTraceNode `Meta.Tactic.data_synth (fun _ => return m!"assigning data") do
withMainTrace (fun _ => return m!"assigning data") do
xs[gId]!.mvarId!.assignIfDefeq g
xs[fId]!.mvarId!.assignIfDefeq f
catch _ =>
Expand Down Expand Up @@ -471,13 +474,9 @@ def letResults (fgGoal : Goal) (f g : Expr) (hf hg : Result) : DataSynthM (Optio
args? := args?.set! hgId hg.proof
args? := args?.set! hfId hf.proof

let proof ←
withTraceNode `Meta.Tactic.data_synth (fun _ => return m!"theorem application") do
mkAppOptM thmName args?
let proof ← mkAppOptM thmName args?

let r ←
withTraceNode `Meta.Tactic.data_synth (fun _ => return m!"getting result from proof") do
fgGoal.getResultFrom proof
let r ← fgGoal.getResultFrom proof

return r

Expand Down Expand Up @@ -553,6 +552,32 @@ def projResults (fGoal : Goal) (f g p₁ p₂ q : Expr) (hg : Result) : DataSynt
return r


def constCase? (goal : Goal) (f : FunData) : DataSynthM (Option Result) := do

let vars := (← f.body.collectFVars |>.run {}).2.fvarSet
let (xs₁, xs₂) := f.xs.split (fun x => vars.contains x.fvarId!)

unless xs₁.size = 0 do return none
withProfileTrace "const case" do
withMainTrace (fun _ => return "constant function") do

let (xs, e) ← goal.mkFreshProofGoal

let thm : DataSynthTheorem ←
getTheoremFromConst (goal.dataSynthDecl.name.append `const_rule)

let .some prf ← tryTheorem? e thm | return none

-- result
let r := Result.mk xs prf goal

-- fix proof
let r ← r.normalize
return r




def decomposeDomain? (goal : Goal) (f : FunData) : DataSynthM (Option Result) := do
if ¬(← read).config.domainDec then
return none
Expand Down Expand Up @@ -601,6 +626,10 @@ def lamCase (goal : Goal) (f : FunData) : DataSynthM (Option Result) := do
partial def mainFun (goal : Goal) (f : FunData) : DataSynthM (Option Result) := do
withProfileTrace "mainFun" do

-- spacial case for constant functions
if let some r ← constCase? goal f then
return r

-- decompose domain if possible
if let some r ← decomposeDomain? goal f then
return r
Expand All @@ -623,7 +652,7 @@ partial def mainFun (goal : Goal) (f : FunData) : DataSynthM (Option Result) :=

def mainFunCached (goal : Goal) (f : FunData) : DataSynthM (Option Result) := do

withTraceNode `Meta.Tactic.data_synth
withMainTrace
(fun r =>
match r with
| .ok (some r) => return m!"[✅] {← goal.pp}"
Expand Down
1 change: 0 additions & 1 deletion SciLean/Tactic/DataSynth/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def congr (r : Result) (rs : Array Simp.Result) : MetaM Result := do

-- proof that original result is equal to the result with normalized data
let hgoal ←
withTraceNode `Meta.Tactic.data_synth (fun _ => return m!"goal congr fold") do
(r.xs.zip rs).foldlM (init:= ← mkEqRefl goal)
(fun g (x,r) =>
match r.proof? with
Expand Down

0 comments on commit df6e8a2

Please sign in to comment.