Skip to content

Commit

Permalink
progress on data_synth tactic
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Dec 6, 2024
1 parent 8923c6f commit 34985a5
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 9 deletions.
8 changes: 8 additions & 0 deletions SciLean/Lean/Meta/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ def etaExpandN (e : Expr) (n : Nat) : MetaM Expr :=
def etaExpand' (e : Expr) : MetaM Expr :=
withDefault do forallTelescopeReducing (← inferType e) fun xs _ => mkLambdaFVars xs (mkAppN e xs).headBeta

/-- Ensures that function is eta expanded -/
def ensureEtaExpanded (e : Expr) : MetaM Expr := do
if e.isLambda then
return e
else
let .forallE n t _ _ ← inferType e | throwError "function expected"
return .lam n t (e.app (.bvar 0)) default


/--
Same as `mkAppM` but does not leave trailing implicit arguments.
Expand Down
3 changes: 2 additions & 1 deletion SciLean/Tactic/DataSynth/ArrayOperations.lean
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ theorem diagonal.arg_x.HasRevFDerivUpdate

@[data_synth]
theorem outerprod.arg_xy.HasRevFDerivUpdate
(x y : W → R^[I]) (x' y') (hx : HasRevFDerivUpdate R x x') (hy : HasRevFDerivUpdate R y y') :
(x : W → R^[I]) (y : W → R^[J]) (x' y') (hx : HasRevFDerivUpdate R x x') (hy : HasRevFDerivUpdate R y y') :
HasRevFDerivUpdate R
(fun w => (x w).outerprod (y w))
(fun w =>
Expand Down Expand Up @@ -551,6 +551,7 @@ set_option trace.Meta.Tactic.data_synth true in



set_option trace.Meta.Tactic.data_synth.input true in
set_option trace.Meta.Tactic.data_synth true in
#check (HasRevFDerivUpdate R (fun x : R^[I] => (∑ i, x[i])*‖x - ‖x‖₂²•1‖₂²) _)
rewrite_by
Expand Down
92 changes: 92 additions & 0 deletions SciLean/Tactic/DataSynth/DefRevDeriv.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import SciLean.Tactic.DataSynth.HasRevFDerivUpdate
import SciLean.Tactic.DataSynth.ArrayOperations
import SciLean.Tactic.DataSynth.Elab

namespace Scilean

open SciLean

open Lean Elab Command Meta

elab "def_rev_deriv" f:ident "in" args:ident* bs:bracketedBinder* "by" tac:tacticSeq : command => do

Elab.Command.liftTermElabM <| do
-- resolve function name
let fId ← ensureNonAmbiguous f (← resolveGlobalConst f)
let info ← getConstInfo fId

forallTelescope info.type fun xs _ => do
Elab.Term.elabBinders bs.raw fun ctx => do


let args := args.map (fun id => id.getId)
let (mainArgs, otherArgs) ← xs.splitM (fun x => do
let n ← x.fvarId!.getUserName
return args.contains n)

-- check if all arguments are present
for arg in args do
if ← mainArgs.allM (fun a => do pure ((← a.fvarId!.getUserName) != arg)) then
throwError s!"function `{fId}` does not have argument `{arg}`"

-- uncurry function appropriatelly
let lvls := info.levelParams.map (fun p => Level.param p)
let g ← liftM <|
mkLambdaFVars mainArgs (mkAppN (Expr.const info.name lvls) xs)
>>=
mkUncurryFun mainArgs.size

let some R ← xs.findSomeM? (fun x => do
let X ← inferType x
if X.isAppOf' ``SciLean.RealScalar then
return X.appArg!
else
return none)
| throwError "can't determine scalar"


let goal ← mkAppM ``HasRevFDerivUpdate #[R,g]
let (xs, _, goal') ← forallMetaTelescope (← inferType goal)
let goal := goal.beta xs

IO.println s!"initial: {← ppExpr goal}"

let m ← mkFreshExprMVar goal

let (_,_) ← runTactic m.mvarId! tac.raw

IO.println s!"result: {← ppExpr (← instantiateMVars goal)}"

let some goal ← Tactic.DataSynth.isDataSynthGoal? goal
| throwError "invalid goal"

pure ()


#check DataArrayN.outerprod

variable {R : Type} [RealScalar R] [PlainDataType R]
(y : R^[n])

#check (HasRevFDerivUpdate R (fun x : R^[n]×R^[n] => x.1.outerprod x.2) _)
rewrite_by
data_synth

#check SciLean.DataArrayN.outerprod.arg_xy.HasRevFDerivUpdate


def Q (q : R^[D]) (l : R^[((D-1)*D)/2]) : R^[D,D] := q.exp.diag + l.lowerTriangular D 1

--set_option trace.Meta.Tactic.data_synth true in
-- set_option trace.Meta.isDefEq true in
def_rev_deriv Q in q l by --
unfold Q
data_synth =>
enter[3]
simp -zeta [DataArrayN.diag]





#check Simp.Config
41 changes: 41 additions & 0 deletions SciLean/Tactic/DataSynth/Elab.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ open Parser.Tactic in
/-- `date_synth` as conv tactic will fill in meta variables in generalized transformation -/
syntax (name:=data_synth_conv) "data_synth" optConfig : conv


/- syntax (name := simp) "simp" optConfig (discharger)? (&" only")?
(" [" withoutPosition((simpStar <|> simpErase <|> simpLemma),*,?) "]")? (location)? : tactic
-/
Expand Down Expand Up @@ -50,3 +51,43 @@ syntax (name:=data_synth_conv) "data_synth" optConfig : conv
| none =>
throwError "`data_synth` failed"
| _ => throwUnsupportedSyntax



open Parser.Tactic Conv in
syntax (name:=data_synth_tac) "data_synth" optConfig ("=>" convSeq)? : tactic

@[tactic data_synth_tac] unsafe def dataSynthTactic : Tactic
| `(tactic| data_synth $cfg:optConfig $[=> $c]?) => do
let m ← getMainGoal
let e ← m.getType

let cfg ← elabDataSynthConfig cfg

let some g ← isDataSynthGoal? e
| throwError "{e} is not `data_synth` goal"

let stateRef : IO.Ref DataSynth.State ← IO.mkRef {}

let (r?,_) ← dataSynth g |>.run {config := cfg} |>.run stateRef
|>.run (← Simp.mkDefaultMethods).toMethodsRef
|>.run {config := cfg.toConfig, simpTheorems := #[← getSimpTheorems]}
|>.run {}

match r? with
| some r =>
let mut e' := r.getSolvedGoal
if let some c := c then
let (e'',eq) ← elabConvRewrite e' #[] (← `(conv| ($c)))
if ← isDefEq e e'' then
m.assign (← mkEqMP eq r.proof)
setGoals []
else
if ← isDefEq e e' then
m.assign r.proof
setGoals []
else
throwError "faield to assign data {e'}"
| none =>
throwError "`data_synth` failed"
| _ => throwUnsupportedSyntax
8 changes: 6 additions & 2 deletions SciLean/Tactic/DataSynth/HasRevFDerivUpdate.lean
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ theorem id_rule : HasRevFDerivUpdate R (fun x : X => x) (fun x => (x, fun dx dx
· fun_prop


@[data_synth]
theorem const_rule (y : Y) : HasRevFDerivUpdate R (fun x : X => y) (fun x => (y, fun _ dx => dx)) := by
constructor
· fun_trans
Expand Down Expand Up @@ -360,6 +359,11 @@ end OverReals


#exit
variable (f : X → X) (f') (hf : HasRevFDerivUpdate R f f')

#check (HasRevFDerivUpdate R (fun x => f x) _) rewrite_by data_synth

#check (HasRevFDerivUpdate R f _) rewrite_by data_synth


set_option trace.Meta.Tactic.data_synth true in
Expand Down Expand Up @@ -394,7 +398,7 @@ set_option pp.deepTerms.threshold 1000000000000000

#check (HasRevFDerivUpdate R (fun x : R => x*x*x*x*x*x*x*x*x*x*x*x*x*x*x*x*x*x*x*x*x) _ )
rewrite_by
data_synth -normalizeCore
data_synth

#check (HasRevFDerivUpdate R (fun x : R×R×R×R => x.1) _) rewrite_by
data_synth
Expand Down
29 changes: 25 additions & 4 deletions SciLean/Tactic/DataSynth/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,22 @@ where
| _ , [] => mkLambdaFVars fvars fn



def Goal.assumption? (goal : Goal) : DataSynthM (Option Result) := do
withProfileTrace "assumption?" do
(← getLCtx).findDeclRevM? fun localDecl => do
if localDecl.isImplementationDetail then
return none
else if localDecl.type.isAppOf' goal.dataSynthDecl.name then
let (_,e) ← goal.mkFreshProofGoal
if (← isDefEq e localDecl.type) then
return ← goal.getResultFrom (.fvar localDecl.fvarId)
else
return none
else
return none


def discharge? (e : Expr) : DataSynthM (Option Expr) := do
(← read).discharge e

Expand All @@ -291,6 +307,11 @@ def synthesizeArgument (x : Expr) : DataSynthM Bool := do
if let .some r ← do dataSynth g then
x.mvarId!.assignIfDefeq (← mkLambdaFVars ys r.proof)
return true

if let some r ← g.assumption? then
x.mvarId!.assignIfDefeq (← mkLambdaFVars ys r.proof)
return true

return false
if b then return true

Expand Down Expand Up @@ -371,16 +392,16 @@ partial def main (goal : Goal) : DataSynthM (Option Result) := do

let thms ← goal.getCandidateTheorems

if thms.size = 0 then
trace[Meta.Tactic.data_synth] "no applicable theorems"
return none

trace[Meta.Tactic.data_synth] "candidates {thms.map (fun t => t.thmName)}"

for thm in thms do
if let .some r ← goal.tryTheorem? thm then
return r

-- try local theorems
if let some r ← goal.assumption? then
return r

return none


Expand Down
5 changes: 3 additions & 2 deletions SciLean/Tactic/DataSynth/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import Lean
import SciLean.Tactic.LSimp.Main
import SciLean.Tactic.DataSynth.Decl
import SciLean.Lean.Meta.Uncurry
import SciLean.Lean.Meta.Basic

import Mathlib.Logic.Equiv.Defs

Expand Down Expand Up @@ -174,8 +175,8 @@ def curryLambdaTelescope (f : Expr) (k : Array Expr → Expr → MetaM α) : Met

lambdaTelescope f k

def getFunData (f : Expr) : MetaM FunData :=
curryLambdaTelescope f fun xs b => do
def getFunData (f : Expr) : MetaM FunData := do
curryLambdaTelescope (← ensureEtaExpanded f) fun xs b => do
let data : FunData :=
{ lctx := ← getLCtx
insts := ← getLocalInstances
Expand Down

0 comments on commit 34985a5

Please sign in to comment.