Skip to content

Commit

Permalink
proper elaborator for lsimp tactic
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Jun 18, 2024
1 parent 4b5699f commit 3f5d37d
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 92 deletions.
50 changes: 34 additions & 16 deletions SciLean/Data/StructType/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,24 @@ instance (priority:=low) instStructTypeDefault : StructType α Unit (fun _ => α
theorem oneHot_unit {X} [Zero X] (x : X)
: oneHot (X:=X) (I:=Unit) () x = x := by rfl

@[simp, ftrans_simp]
theorem structProj_unit (x : E)
: structProj (I:=Unit) x ()
=
x := rfl

@[simp, ftrans_simp]
theorem structMake_unit (f : Unit → E)
: structMake (I:=Unit) f
=
f () := rfl

@[simp, ftrans_simp]
theorem structModify_unit (f : E → E) (x : E)
: structModify (I:=Unit) () f x
=
f x := rfl


-- Pi --------------------------------------------------------------------------
--------------------------------------------------------------------------------
Expand Down Expand Up @@ -219,15 +237,15 @@ instance instStrucTypeSigma
-- by
-- simp[structMake]

-- @[simp low, ftrans_simp low]
-- theorem structModify_inl [StructType E I EI] [StructType F J FJ] (i : I) (f : EI i → EI i) (xy : E×F)
-- : structModify (I:=I⊕J) (.inl i) f xy
-- =
-- {xy with fst := structModify i f xy.1} :=
-- by
-- conv =>
-- lhs
-- simp[structModify]
@[simp low, ftrans_simp low]
theorem structModify_inl [StructType E I EI] [StructType F J FJ] (i : I) (f : EI i → EI i) (xy : E×F)
: structModify (I:=I⊕J) (.inl i) f xy
=
{xy with fst := structModify i f xy.1} :=
by
conv =>
lhs
simp[structModify]

-- @[simp, ftrans_simp]
-- theorem structModify_inl' [StructType E I EI] [StructType F J FJ] (i : I) (f : EI i → EI i) (x : E) (y : F)
Expand All @@ -239,13 +257,13 @@ instance instStrucTypeSigma
-- lhs
-- simp[structModify]

-- @[simp low, ftrans_simp low]
-- theorem structModify_inr [StructType E I EI] [StructType F J FJ] (j : J) (f : FJ j → FJ j) (xy : E×F)
-- : structModify (I:=I⊕J) (.inr j) f xy
-- =
-- (xy.1, structModify j f xy.2) :=
-- by
-- simp[structModify]
@[simp low, ftrans_simp low]
theorem structModify_inr [StructType E I EI] [StructType F J FJ] (j : J) (f : FJ j → FJ j) (xy : E×F)
: structModify (I:=I⊕J) (.inr j) f xy
=
(xy.1, structModify j f xy.2) :=
by
simp[structModify]

-- @[simp, ftrans_simp]
-- theorem structModify_inr' [StructType E I EI] [StructType F J FJ] (j : J) (f : FJ j → FJ j) (x : E) (y : F)
Expand Down
67 changes: 18 additions & 49 deletions SciLean/Tactic/LSimp/Elab.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,56 +9,25 @@ open TSyntax.Compat
open Lean Meta


open Lean.Parser.Tactic
syntax (name := Parser.lsimp) "lsimp" (config)? (discharger)? (&" only")?
(" [" withoutPosition((simpStar <|> simpErase <|> simpLemma),*,?) "]")? : tactic


def callLSimpAux (e : Expr) (k : Expr → Expr → Array Expr → MetaM α) : MetaM α := do

let stateRef : IO.Ref Simp.State ← IO.mkRef {}
let lcacheRef : IO.Ref Cache ← IO.mkRef {}

let mut simprocs : Simp.Simprocs := {}
simprocs ← simprocs.add ``Mathlib.Meta.FunTrans.fun_trans_simproc false
let .some ext ← getSimpExtension? `ftrans_simp | throwError "can't find theorems!"
let thms ← ext.getTheorems
>>= (·.addDeclToUnfold ``scalarGradient)
>>= (·.addDeclToUnfold ``scalarCDeriv)

let r :=
(lsimp e).run
(Simp.mkDefaultMethodsCore #[simprocs])
{config:={zeta:=false,singlePass:=false},simpTheorems:=#[thms]}
⟨lcacheRef, stateRef, {}⟩

let (a,t) ← Aesop.time <| r.runInMeta (fun (r,s) => do
trace[Meta.Tactic.simp.numSteps] "{(← stateRef.get).numSteps}"
s.printTimings

-- IO.println "cache"
-- (← s.simpState.get).cache.forM fun e v => do
-- IO.println s!"{← ppExpr e}"

k r.expr (← r.getProof) r.vars)

trace[Meta.Tactic.simp.time] "lsimp took {t.printAsMillis}"

return a


def callLSimp (e : Expr) : MetaM (Expr×Expr) := do
callLSimpAux e (fun e prf vars => do
return (← mkLambdaFVars vars e, ← mkLambdaFVars vars prf))



open Lean.Parser.Tactic in
syntax (name:=lsimp_conv) "lsimp" /-(config)? (discharger)? (normalizer)?-/ : conv
syntax (name:=lsimp_conv) "lsimp" (config)? (discharger)? (&" only")?
(" [" withoutPosition((simpStar <|> simpErase <|> simpLemma),*,?) "]")? : conv


open Lean Elab Tactic in
@[tactic lsimp_conv] unsafe def lsimpConv : Tactic := fun _ => do
let e ← Conv.getLhs
let (e',prf) ← callLSimp e
Conv.updateLhs e' prf
@[tactic lsimp_conv] unsafe def lsimpConv : Tactic := fun stx => do
withMainContext do withSimpDiagnostics do
let { ctx, simprocs, dischargeWrapper } ← mkSimpContext stx (eraseLocal := false)
let ctx := { ctx with config := { ctx.config with zeta := false } }
let stats ← dischargeWrapper.with fun discharge? => do
let e ← Conv.getLhs
let ((e',prf),stats) ←
lsimpMain e /- k -/ ctx simprocs discharge?
(k := fun r => do let r ← r.bindVars; pure (r.expr, ← r.getProof))
Conv.updateLhs e' prf
return stats

if tactic.simp.trace.get (← getOptions) then
traceSimpCall stx stats.usedTheorems

return stats.diag
32 changes: 32 additions & 0 deletions SciLean/Tactic/LSimp/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -726,3 +726,35 @@ where

trace[Meta.Tactic.simp.heads] "{repr e.toHeadIndex}"
simpLoop e


/-- Run `lsimp` on `e` and process result with `k r' where `k` is executed in modified local context
where all `r.vars` are valid free vars.
-/
def main (e : Expr) (k : Result → MetaM α)
(ctx : Simp.Context)
(stats : Simp.Stats := {})
(methods : Simp.Methods := {}) : MetaM (α × Simp.Stats) := do

-- prepare state
let lcacheRef : IO.Ref Cache ← IO.mkRef {}
let stateRef : IO.Ref Simp.State ← IO.mkRef {stats with}
let state : State := { cache := lcacheRef, simpState := stateRef }

-- load context
let ctx := { ctx with config := (← ctx.config.updateArith), lctxInitIndices := (← getLCtx).numIndices }
Simp.withSimpContext ctx do

let (a,s) ← (lsimp e methods ctx state).runInMeta
(fun (r,s) => do pure (← k r,s))

let simpState ← s.simpState.get
return (a, {simpState with})


def lsimpMain (e : Expr) (k : Result → MetaM α)
(ctx : Simp.Context) (simprocs : Simp.SimprocsArray := #[]) (discharge? : Option Simp.Discharge := none)
(stats : Simp.Stats := {}) : MetaM (α × Simp.Stats) := do profileitM Exception "lsimp" (← getOptions) do
match discharge? with
| none => main e k ctx stats (methods := Simp.mkDefaultMethodsCore simprocs)
| some d => main e k ctx stats (methods := Simp.mkMethods simprocs d (wellBehavedDischarge := false))
7 changes: 4 additions & 3 deletions SciLean/Tactic/LSimp/Test.lean
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ example (n : Nat) :
a + b)
=
n + (n + 3) + n + 2 + (n + (n + 3) + n + 2 + 5) := by
(conv => lhs; lsimp)
(conv => lhs; lsimp (config:={zeta:=false}))

example (n : Nat) (i : Fin n) :
(let j := 2*i.1
Expand All @@ -186,7 +186,7 @@ example (n : Nat) (i : Fin n) :
let hj : j < 2*n := by omega
let j : Fin (2*n) := ⟨j, hj⟩
(j + (j + j + j)) := by
(conv => lhs; lsimp)
(conv => lhs; lsimp (config:={zeta:=false}))

-- tests under lambda binder

Expand All @@ -195,6 +195,7 @@ example :
=
(fun n : Nat => n) := by (conv => lhs; lsimp)


example :
(fun n => let a := 1; a + n)
=
Expand All @@ -214,4 +215,4 @@ example :
a)
=
(fun n => n + n) := by
(conv => lhs; lsimp)
(conv => lhs; lsimp (config:={zeta:=false}))
1 change: 0 additions & 1 deletion SciLean/Tactic/LSimp/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ structure State where
simpState : IO.Ref Simp.State
timings : Batteries.RBMap String Aesop.Nanos compare := {}


abbrev LSimpM := ReaderT Simp.Methods $ ReaderT Simp.Context $ StateT State MetaLCtxM

instance : MonadLift SimpM LSimpM where
Expand Down
48 changes: 25 additions & 23 deletions test/lsimp_basic_tests.lean
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,16 @@ elab "timeTactic" t:conv : conv => do
IO.println s!"tactic {t.raw.prettyPrint} took {time.printAsMillis}"



#check (∇ x : Float, let y := x * x; x * y)
rewrite_by
unfold scalarGradient
lsimp
lsimp only [Mathlib.Meta.FunTrans.fun_trans_simproc, ftrans_simp]


macro "lautodiff" : conv =>
`(conv| (unfold scalarGradient scalarCDeriv;
lsimp (config := {zeta:=false}) only
[Mathlib.Meta.FunTrans.fun_trans_simproc, ftrans_simp]))

-- set_option trace.Meta.Tactic.fun_trans true in

Expand All @@ -145,8 +149,9 @@ elab "timeTactic" t:conv : conv => do
-- unfold scalarGradient
-- timeTactic lsimp


set_option trace.Meta.Tactic.fun_trans true in
-- set_option trace.Meta.Tactic.fun_trans true in
set_option trace.Meta.Tactic.simp.unify true in
set_option trace.Meta.Tactic.simp.rewrite true in
#check (∇ x : Float,
let x1 := x * x
let x2 := x * x1
Expand All @@ -161,28 +166,25 @@ set_option trace.Meta.Tactic.fun_trans true in
-- let x11 := x * x10
x6)
rewrite_by
unfold scalarGradient
lsimp

unfold scalarGradient scalarCDeriv
lautodiff

#exit

#check (∂ x : Float, let y := x * x; x * y)
rewrite_by
unfold scalarCDeriv
lsimp
lautodiff


#check (∂> x : Float, let y := x * x; x * y)
rewrite_by
lsimp
lautodiff


-- #check Nat

#check (∂> x : Float, let y := x * x; let z := x * y; x * y * z)
rewrite_by
lsimp
lautodiff

set_option profiler true

Expand All @@ -200,22 +202,22 @@ open SciLean
#check (structProj 1 ())
rewrite_by
unfold scalarGradient
lsimp
lsimp
lautodiff
lautodiff



-- #check (∇ x : Float, let y := x * x; let z := x * y; x * y * z)
-- rewrite_by
-- unfold scalarGradient
-- lsimp
-- lsimp-- 16107 steps & 4.45s | with cache: 1295 steps && 709ms
-- lautodiff
-- lautodiff-- 16107 steps & 4.45s | with cache: 1295 steps && 709ms


-- #check (∇ x : Float, let y := x * x; let z := x * y; x * y * z)
-- rewrite_by
-- unfold scalarGradient
-- lsimp -- 16107 steps & 4.45s | with cache: 1295 steps && 709ms
-- lautodiff -- 16107 steps & 4.45s | with cache: 1295 steps && 709ms


#check (∇ x : Float,
Expand All @@ -232,7 +234,7 @@ open SciLean
x6)
rewrite_by
unfold scalarGradient
lsimp
lautodiff



Expand All @@ -251,7 +253,7 @@ open SciLean
x6)
rewrite_by
unfold scalarGradient
lsimp
lautodiff



Expand All @@ -269,7 +271,7 @@ open SciLean
x6)
rewrite_by
unfold scalarGradient
lsimp
lautodiff



Expand All @@ -289,7 +291,7 @@ open SciLean
-- x10 * x11)
-- rewrite_by
-- unfold scalarGradient
-- lsimp
-- lautodiff



Expand All @@ -306,7 +308,7 @@ open SciLean
x3 * x4)
rewrite_by
unfold scalarGradient
lsimp
lautodiff
simp


Expand All @@ -321,7 +323,7 @@ open SciLean
x * x1 * x2 * x3 * x4 * x5)
rewrite_by
unfold scalarGradient
lsimp
lautodiff



Expand Down

0 comments on commit 3f5d37d

Please sign in to comment.