Skip to content

Commit

Permalink
new optConfig syntax for autodiff
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Dec 4, 2024
1 parent 082e0c4 commit 337ca8c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 18 deletions.
16 changes: 8 additions & 8 deletions SciLean/Tactic/Autodiff.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,30 @@ import SciLean.Tactic.LFunTrans
namespace SciLean.Tactic

open Lean Meta Elab Tactic Mathlib.Meta.FunTrans Lean.Parser.Tactic in
syntax (name := lautodiffConvStx) "autodiff" (config)? (discharger)?
syntax (name := lautodiffConvStx) "autodiff" optConfig (discharger)?
(" [" withoutPosition((simpStar <|> simpErase <|> simpLemma),*) "]")? : conv

open Lean Meta Elab Tactic Mathlib.Meta.FunTrans Lean.Parser.Tactic in
syntax (name := lautodiffTacticStx) "autodiff" (config)? (discharger)?
syntax (name := lautodiffTacticStx) "autodiff" optConfig (discharger)?
(" [" withoutPosition((simpStar <|> simpErase <|> simpLemma),*) "]")? : tactic

macro_rules
| `(conv| autodiff $[$cfg]? $[$disch]? $[[$a,*]]?) => do
| `(conv| autodiff $cfg $[$disch]? $[[$a,*]]?) => do
if a.isSome then
`(conv| ((try unfold deriv fgradient adjointFDeriv); -- todo: investigate why simp sometimes does not unfold and remove this line
lfun_trans $[$cfg]? $[$disch]? only $[[deriv, fgradient, adjointFDeriv, simp_core, $a,*]]?))
lfun_trans $cfg $[$disch]? only $[[deriv, fgradient, adjointFDeriv, simp_core, $a,*]]?))
else
`(conv| ((try unfold deriv fgradient adjointFDeriv);
lfun_trans $[$cfg]? $[$disch]? only [deriv, fgradient, adjointFDeriv, simp_core]))
lfun_trans $cfg $[$disch]? only [deriv, fgradient, adjointFDeriv, simp_core]))

macro_rules
| `(tactic| autodiff $[$cfg]? $[$disch]? $[[$a,*]]?) => do
| `(tactic| autodiff $cfg $[$disch]? $[[$a,*]]?) => do
if a.isSome then
`(tactic| ((try unfold deriv fgradient adjointFDeriv);
lfun_trans $[$cfg]? $[$disch]? only $[[deriv, fgradient, adjointFDeriv, simp_core, $a,*]]?))
lfun_trans $cfg $[$disch]? only $[[deriv, fgradient, adjointFDeriv, simp_core, $a,*]]?))
else
`(tactic| ((try unfold deriv fgradient adjointFDeriv);
lfun_trans $[$cfg]? $[$disch]? only [deriv, fgradient, adjointFDeriv, simp_core]))
lfun_trans $cfg $[$disch]? only [deriv, fgradient, adjointFDeriv, simp_core]))

-- open Lean Meta
-- simproc_decl lift_lets_simproc (_) := fun e => do
Expand Down
16 changes: 8 additions & 8 deletions SciLean/Tactic/LFunTrans.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,27 @@ namespace SciLean.Tactic
open Lean Meta Elab Tactic Mathlib.Meta.FunTrans Lean.Parser.Tactic


syntax (name := lfunTransTacStx) "lfun_trans" (config)? (discharger)? (&" only")?
syntax (name := lfunTransTacStx) "lfun_trans" optConfig (discharger)? (&" only")?
(" [" withoutPosition((simpStar <|> simpErase <|> simpLemma),*,?) "]")? (location)? : tactic

syntax (name := lfunTransConvStx) "lfun_trans" (config)? (discharger)? (&" only")?
syntax (name := lfunTransConvStx) "lfun_trans" optConfig (discharger)? (&" only")?
(" [" withoutPosition((simpStar <|> simpErase <|> simpLemma),*) "]")? : conv


@[tactic lfunTransTacStx]
def lfunTransTac : Tactic := fun stx => do
match stx with
| `(tactic| lfun_trans $[$cfg]? $[$disch]? $[only]? $[[$a,*]]? $[$loc]?) => do
| `(tactic| lfun_trans $cfg $[$disch]? $[only]? $[[$a,*]]? $[$loc]?) => do

-- set fun_trans config
funTransContext.modify
fun c => { c with funPropContext := { c.funPropContext with disch := stxToDischarge disch}}

let a := a.getD (Syntax.TSepArray.mk #[])
if stx[3].isNone then
evalTactic (← `(tactic| lsimp $[$cfg]? $[$disch]? [↓fun_trans_simproc,$a,*]))
evalTactic (← `(tactic| lsimp $cfg $[$disch]? [↓fun_trans_simproc,$a,*]))
else
evalTactic (← `(tactic| lsimp $[$cfg]? $[$disch]? only [↓fun_trans_simproc,$a,*]))
evalTactic (← `(tactic| lsimp $cfg $[$disch]? only [↓fun_trans_simproc,$a,*]))

-- reset fun_trans config
funTransContext.modify fun _ => {}
Expand All @@ -41,17 +41,17 @@ def lfunTransTac : Tactic := fun stx => do
@[tactic lfunTransConvStx]
def lfunTransConv : Tactic := fun stx => do
match stx with
| `(conv| lfun_trans $[$cfg]? $[$disch]? $[only]? $[[$a,*]]?) => do
| `(conv| lfun_trans $cfg $[$disch]? $[only]? $[[$a,*]]?) => do

-- set fun_trans config
funTransContext.modify
fun c => { c with funPropContext := { c.funPropContext with disch := stxToDischarge disch}}

let a := a.getD (Syntax.TSepArray.mk #[])
if stx[3].isNone then
evalTactic (← `(conv| lsimp $[$cfg]? $[$disch]? [↓fun_trans_simproc,$a,*]))
evalTactic (← `(conv| lsimp $cfg $[$disch]? [↓fun_trans_simproc,$a,*]))
else
evalTactic (← `(conv| lsimp $[$cfg]? $[$disch]? only [↓fun_trans_simproc,$a,*]))
evalTactic (← `(conv| lsimp $cfg $[$disch]? only [↓fun_trans_simproc,$a,*]))

-- reset fun_trans config
funTransContext.modify fun _ => {}
Expand Down
4 changes: 2 additions & 2 deletions SciLean/Tactic/LSimp/Elab.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ open Lean Meta


open Lean.Parser.Tactic in
syntax (name:=lsimp_conv) "lsimp" (config)? (discharger)? (&" only")?
syntax (name:=lsimp_conv) "lsimp" optConfig (discharger)? (&" only")?
(" [" withoutPosition((simpStar <|> simpErase <|> simpLemma),*,?) "]")? : conv


open Lean.Parser.Tactic in
syntax (name:=lsimp_tactic) "lsimp" (config)? (discharger)? (&" only")?
syntax (name:=lsimp_tactic) "lsimp" optConfig (discharger)? (&" only")?
(" [" withoutPosition((simpStar <|> simpErase <|> simpLemma),*,?) "]")? : tactic


Expand Down

0 comments on commit 337ca8c

Please sign in to comment.