From 34985a51e4e8aea41dc6e0d5ecd78c31e527d967 Mon Sep 17 00:00:00 2001 From: lecopivo Date: Thu, 5 Dec 2024 20:25:41 -0500 Subject: [PATCH] progress on data_synth tactic --- SciLean/Lean/Meta/Basic.lean | 8 ++ SciLean/Tactic/DataSynth/ArrayOperations.lean | 3 +- SciLean/Tactic/DataSynth/DefRevDeriv.lean | 92 +++++++++++++++++++ SciLean/Tactic/DataSynth/Elab.lean | 41 +++++++++ .../Tactic/DataSynth/HasRevFDerivUpdate.lean | 8 +- SciLean/Tactic/DataSynth/Main.lean | 29 +++++- SciLean/Tactic/DataSynth/Types.lean | 5 +- 7 files changed, 177 insertions(+), 9 deletions(-) create mode 100644 SciLean/Tactic/DataSynth/DefRevDeriv.lean diff --git a/SciLean/Lean/Meta/Basic.lean b/SciLean/Lean/Meta/Basic.lean index 5e6643c1..d36aac34 100644 --- a/SciLean/Lean/Meta/Basic.lean +++ b/SciLean/Lean/Meta/Basic.lean @@ -113,6 +113,14 @@ def etaExpandN (e : Expr) (n : Nat) : MetaM Expr := def etaExpand' (e : Expr) : MetaM Expr := withDefault do forallTelescopeReducing (← inferType e) fun xs _ => mkLambdaFVars xs (mkAppN e xs).headBeta +/-- Ensures that function is eta expanded -/ +def ensureEtaExpanded (e : Expr) : MetaM Expr := do + if e.isLambda then + return e + else + let .forallE n t _ _ ← inferType e | throwError "function expected" + return .lam n t (e.app (.bvar 0)) default + /-- Same as `mkAppM` but does not leave trailing implicit arguments. diff --git a/SciLean/Tactic/DataSynth/ArrayOperations.lean b/SciLean/Tactic/DataSynth/ArrayOperations.lean index ffb7433d..e875cac2 100644 --- a/SciLean/Tactic/DataSynth/ArrayOperations.lean +++ b/SciLean/Tactic/DataSynth/ArrayOperations.lean @@ -79,7 +79,7 @@ theorem diagonal.arg_x.HasRevFDerivUpdate @[data_synth] theorem outerprod.arg_xy.HasRevFDerivUpdate - (x y : W → R^[I]) (x' y') (hx : HasRevFDerivUpdate R x x') (hy : HasRevFDerivUpdate R y y') : + (x : W → R^[I]) (y : W → R^[J]) (x' y') (hx : HasRevFDerivUpdate R x x') (hy : HasRevFDerivUpdate R y y') : HasRevFDerivUpdate R (fun w => (x w).outerprod (y w)) (fun w => @@ -551,6 +551,7 @@ set_option trace.Meta.Tactic.data_synth true in +set_option trace.Meta.Tactic.data_synth.input true in set_option trace.Meta.Tactic.data_synth true in #check (HasRevFDerivUpdate R (fun x : R^[I] => (∑ i, x[i])*‖x - ‖x‖₂²•1‖₂²) _) rewrite_by diff --git a/SciLean/Tactic/DataSynth/DefRevDeriv.lean b/SciLean/Tactic/DataSynth/DefRevDeriv.lean new file mode 100644 index 00000000..b9ad4b76 --- /dev/null +++ b/SciLean/Tactic/DataSynth/DefRevDeriv.lean @@ -0,0 +1,92 @@ +import SciLean.Tactic.DataSynth.HasRevFDerivUpdate +import SciLean.Tactic.DataSynth.ArrayOperations +import SciLean.Tactic.DataSynth.Elab + +namespace Scilean + +open SciLean + +open Lean Elab Command Meta + +elab "def_rev_deriv" f:ident "in" args:ident* bs:bracketedBinder* "by" tac:tacticSeq : command => do + + Elab.Command.liftTermElabM <| do + -- resolve function name + let fId ← ensureNonAmbiguous f (← resolveGlobalConst f) + let info ← getConstInfo fId + + forallTelescope info.type fun xs _ => do + Elab.Term.elabBinders bs.raw fun ctx => do + + + let args := args.map (fun id => id.getId) + let (mainArgs, otherArgs) ← xs.splitM (fun x => do + let n ← x.fvarId!.getUserName + return args.contains n) + + -- check if all arguments are present + for arg in args do + if ← mainArgs.allM (fun a => do pure ((← a.fvarId!.getUserName) != arg)) then + throwError s!"function `{fId}` does not have argument `{arg}`" + + -- uncurry function appropriatelly + let lvls := info.levelParams.map (fun p => Level.param p) + let g ← liftM <| + mkLambdaFVars mainArgs (mkAppN (Expr.const info.name lvls) xs) + >>= + mkUncurryFun mainArgs.size + + let some R ← xs.findSomeM? (fun x => do + let X ← inferType x + if X.isAppOf' ``SciLean.RealScalar then + return X.appArg! + else + return none) + | throwError "can't determine scalar" + + + let goal ← mkAppM ``HasRevFDerivUpdate #[R,g] + let (xs, _, goal') ← forallMetaTelescope (← inferType goal) + let goal := goal.beta xs + + IO.println s!"initial: {← ppExpr goal}" + + let m ← mkFreshExprMVar goal + + let (_,_) ← runTactic m.mvarId! tac.raw + + IO.println s!"result: {← ppExpr (← instantiateMVars goal)}" + + let some goal ← Tactic.DataSynth.isDataSynthGoal? goal + | throwError "invalid goal" + + pure () + + +#check DataArrayN.outerprod + +variable {R : Type} [RealScalar R] [PlainDataType R] + (y : R^[n]) + +#check (HasRevFDerivUpdate R (fun x : R^[n]×R^[n] => x.1.outerprod x.2) _) + rewrite_by + data_synth + +#check SciLean.DataArrayN.outerprod.arg_xy.HasRevFDerivUpdate + + +def Q (q : R^[D]) (l : R^[((D-1)*D)/2]) : R^[D,D] := q.exp.diag + l.lowerTriangular D 1 + +--set_option trace.Meta.Tactic.data_synth true in +-- set_option trace.Meta.isDefEq true in +def_rev_deriv Q in q l by -- + unfold Q + data_synth => + enter[3] + simp -zeta [DataArrayN.diag] + + + + + +#check Simp.Config diff --git a/SciLean/Tactic/DataSynth/Elab.lean b/SciLean/Tactic/DataSynth/Elab.lean index 6b0c7f84..9460c188 100644 --- a/SciLean/Tactic/DataSynth/Elab.lean +++ b/SciLean/Tactic/DataSynth/Elab.lean @@ -11,6 +11,7 @@ open Parser.Tactic in /-- `date_synth` as conv tactic will fill in meta variables in generalized transformation -/ syntax (name:=data_synth_conv) "data_synth" optConfig : conv + /- syntax (name := simp) "simp" optConfig (discharger)? (&" only")? (" [" withoutPosition((simpStar <|> simpErase <|> simpLemma),*,?) "]")? (location)? : tactic -/ @@ -50,3 +51,43 @@ syntax (name:=data_synth_conv) "data_synth" optConfig : conv | none => throwError "`data_synth` failed" | _ => throwUnsupportedSyntax + + + +open Parser.Tactic Conv in +syntax (name:=data_synth_tac) "data_synth" optConfig ("=>" convSeq)? : tactic + +@[tactic data_synth_tac] unsafe def dataSynthTactic : Tactic +| `(tactic| data_synth $cfg:optConfig $[=> $c]?) => do + let m ← getMainGoal + let e ← m.getType + + let cfg ← elabDataSynthConfig cfg + + let some g ← isDataSynthGoal? e + | throwError "{e} is not `data_synth` goal" + + let stateRef : IO.Ref DataSynth.State ← IO.mkRef {} + + let (r?,_) ← dataSynth g |>.run {config := cfg} |>.run stateRef + |>.run (← Simp.mkDefaultMethods).toMethodsRef + |>.run {config := cfg.toConfig, simpTheorems := #[← getSimpTheorems]} + |>.run {} + + match r? with + | some r => + let mut e' := r.getSolvedGoal + if let some c := c then + let (e'',eq) ← elabConvRewrite e' #[] (← `(conv| ($c))) + if ← isDefEq e e'' then + m.assign (← mkEqMP eq r.proof) + setGoals [] + else + if ← isDefEq e e' then + m.assign r.proof + setGoals [] + else + throwError "faield to assign data {e'}" + | none => + throwError "`data_synth` failed" +| _ => throwUnsupportedSyntax diff --git a/SciLean/Tactic/DataSynth/HasRevFDerivUpdate.lean b/SciLean/Tactic/DataSynth/HasRevFDerivUpdate.lean index 90a9fdea..eba5f78e 100644 --- a/SciLean/Tactic/DataSynth/HasRevFDerivUpdate.lean +++ b/SciLean/Tactic/DataSynth/HasRevFDerivUpdate.lean @@ -35,7 +35,6 @@ theorem id_rule : HasRevFDerivUpdate R (fun x : X => x) (fun x => (x, fun dx dx · fun_prop -@[data_synth] theorem const_rule (y : Y) : HasRevFDerivUpdate R (fun x : X => y) (fun x => (y, fun _ dx => dx)) := by constructor · fun_trans @@ -360,6 +359,11 @@ end OverReals #exit +variable (f : X → X) (f') (hf : HasRevFDerivUpdate R f f') + +#check (HasRevFDerivUpdate R (fun x => f x) _) rewrite_by data_synth + +#check (HasRevFDerivUpdate R f _) rewrite_by data_synth set_option trace.Meta.Tactic.data_synth true in @@ -394,7 +398,7 @@ set_option pp.deepTerms.threshold 1000000000000000 #check (HasRevFDerivUpdate R (fun x : R => x*x*x*x*x*x*x*x*x*x*x*x*x*x*x*x*x*x*x*x*x) _ ) rewrite_by - data_synth -normalizeCore + data_synth #check (HasRevFDerivUpdate R (fun x : R×R×R×R => x.1) _) rewrite_by data_synth diff --git a/SciLean/Tactic/DataSynth/Main.lean b/SciLean/Tactic/DataSynth/Main.lean index 2beb1a68..cc70a0f9 100644 --- a/SciLean/Tactic/DataSynth/Main.lean +++ b/SciLean/Tactic/DataSynth/Main.lean @@ -272,6 +272,22 @@ where | _ , [] => mkLambdaFVars fvars fn + +def Goal.assumption? (goal : Goal) : DataSynthM (Option Result) := do + withProfileTrace "assumption?" do + (← getLCtx).findDeclRevM? fun localDecl => do + if localDecl.isImplementationDetail then + return none + else if localDecl.type.isAppOf' goal.dataSynthDecl.name then + let (_,e) ← goal.mkFreshProofGoal + if (← isDefEq e localDecl.type) then + return ← goal.getResultFrom (.fvar localDecl.fvarId) + else + return none + else + return none + + def discharge? (e : Expr) : DataSynthM (Option Expr) := do (← read).discharge e @@ -291,6 +307,11 @@ def synthesizeArgument (x : Expr) : DataSynthM Bool := do if let .some r ← do dataSynth g then x.mvarId!.assignIfDefeq (← mkLambdaFVars ys r.proof) return true + + if let some r ← g.assumption? then + x.mvarId!.assignIfDefeq (← mkLambdaFVars ys r.proof) + return true + return false if b then return true @@ -371,16 +392,16 @@ partial def main (goal : Goal) : DataSynthM (Option Result) := do let thms ← goal.getCandidateTheorems - if thms.size = 0 then - trace[Meta.Tactic.data_synth] "no applicable theorems" - return none - trace[Meta.Tactic.data_synth] "candidates {thms.map (fun t => t.thmName)}" for thm in thms do if let .some r ← goal.tryTheorem? thm then return r + -- try local theorems + if let some r ← goal.assumption? then + return r + return none diff --git a/SciLean/Tactic/DataSynth/Types.lean b/SciLean/Tactic/DataSynth/Types.lean index 8299eb5b..b4ece65a 100644 --- a/SciLean/Tactic/DataSynth/Types.lean +++ b/SciLean/Tactic/DataSynth/Types.lean @@ -2,6 +2,7 @@ import Lean import SciLean.Tactic.LSimp.Main import SciLean.Tactic.DataSynth.Decl import SciLean.Lean.Meta.Uncurry +import SciLean.Lean.Meta.Basic import Mathlib.Logic.Equiv.Defs @@ -174,8 +175,8 @@ def curryLambdaTelescope (f : Expr) (k : Array Expr → Expr → MetaM α) : Met lambdaTelescope f k -def getFunData (f : Expr) : MetaM FunData := - curryLambdaTelescope f fun xs b => do +def getFunData (f : Expr) : MetaM FunData := do + curryLambdaTelescope (← ensureEtaExpanded f) fun xs b => do let data : FunData := { lctx := ← getLCtx insts := ← getLocalInstances