From 845a8b07bfe494b43507ddf9c2bc43695534c563 Mon Sep 17 00:00:00 2001 From: lecopivo Date: Tue, 20 Aug 2024 17:35:23 -0400 Subject: [PATCH] revFDerivProj rules for ArrayType.get --- SciLean.lean | 1 + SciLean/Analysis/Calculus/FwdFDeriv.lean | 64 +- .../Analysis/Calculus/Notation/Gradient.lean | 9 +- SciLean/Analysis/Calculus/RevFDeriv.lean | 11 + SciLean/Analysis/Calculus/RevFDerivProj.lean | 34 +- SciLean/Data/ArrayType/Properties.lean | 45 ++ SciLean/Data/StructType/Algebra.lean | 3 - SciLean/Tactic/FunTrans/Core.lean | 11 +- SciLean/Tactic/FunTrans/Theorems.lean | 9 +- doc/talk/august_umbc_lecture.lean | 556 ++++++++++++++++++ doc/talk/august_umbc_lecture.org | 100 ++++ 11 files changed, 796 insertions(+), 47 deletions(-) create mode 100644 doc/talk/august_umbc_lecture.lean create mode 100644 doc/talk/august_umbc_lecture.org diff --git a/SciLean.lean b/SciLean.lean index fe0a5b21..faa8b022 100644 --- a/SciLean.lean +++ b/SciLean.lean @@ -117,6 +117,7 @@ import SciLean.MeasureTheory.WeakIntegral -- import SciLean.Meta.DerivingOp import SciLean.Meta.GenerateAddGroupHomSimp import SciLean.Meta.GenerateFunProp +import SciLean.Meta.GenerateFunTrans import SciLean.Meta.GenerateLinearMapSimp import SciLean.Meta.Notation.Do import SciLean.Meta.SimpAttr diff --git a/SciLean/Analysis/Calculus/FwdFDeriv.lean b/SciLean/Analysis/Calculus/FwdFDeriv.lean index d49f4d21..c336b58d 100644 --- a/SciLean/Analysis/Calculus/FwdFDeriv.lean +++ b/SciLean/Analysis/Calculus/FwdFDeriv.lean @@ -118,9 +118,33 @@ theorem pi_rule open SciLean --- Prod.mk -----------------------------------v--------------------------------- +-- of linear function ---------------------------------------------------------- -------------------------------------------------------------------------------- +@[fun_trans] +theorem fwdFDeriv_linear + (f : X → Y) (hf : IsContinuousLinearMap K f) : + fwdFDeriv K f + = + fun x dx => (f x, f dx) := by unfold fwdFDeriv; fun_trans + + +-- Prod.mk 
--------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +@[fun_trans] +theorem Prod.mk.arg_fstsnd.fwdFDeriv_rule + (g : X → Y) (hg : Differentiable K g) + (f : X → Z) (hf : Differentiable K f) : + fwdFDeriv K (fun x => (g x, f x)) + = + fun x dx => + let ydy := fwdFDeriv K g x dx + let zdz := fwdFDeriv K f x dx + ((ydy.1, zdz.1), (ydy.2, zdz.2)) := by + unfold fwdFDeriv; fun_trans + + @[fun_trans] theorem Prod.mk.arg_fstsnd.fwdFDeriv_rule_at (x : X) (g : X → Y) (hg : DifferentiableAt K g x) @@ -137,6 +161,13 @@ theorem Prod.mk.arg_fstsnd.fwdFDeriv_rule_at (x : X) -- Prod.fst -------------------------------------------------------------------- -------------------------------------------------------------------------------- +@[fun_trans] +theorem Prod.fst.arg_self.fwdFDeriv_rule : + fwdFDeriv K (fun xy : X×Y => xy.1) + = + fun xy dxy => (xy.1, dxy.1) := by + unfold fwdFDeriv; fun_trans + @[fun_trans] theorem Prod.fst.arg_self.fwdFDeriv_rule_at (x : X) (f : X → Y×Z) (hf : DifferentiableAt K f x) : @@ -151,6 +182,14 @@ theorem Prod.fst.arg_self.fwdFDeriv_rule_at (x : X) -- Prod.snd -------------------------------------------------------------------- -------------------------------------------------------------------------------- +@[fun_trans] +theorem Prod.snd.arg_self.fwdFDeriv_rule : + fwdFDeriv K (fun xy : X×Y => xy.2) + = + fun xy dxy => (xy.2, dxy.2) := by + unfold fwdFDeriv; fun_trans + + @[fun_trans] theorem Prod.snd.arg_self.fwdFDeriv_rule_at (x : X) (f : X → Y×Z) (hf : DifferentiableAt K f x) : @@ -239,6 +278,13 @@ theorem HSMul.hSMul.arg_a0a1.fwdFDeriv_rule_at (x : X) -- HDiv.hDiv ------------------------------------------------------------------- -------------------------------------------------------------------------------- +@[fun_trans] +theorem HDiv.hDiv.arg_a0.fwdFDeriv_rule (y : K) : + (fwdFDeriv K fun x => x / y) + = + fun x dx => (x / y, dx / y) := by + 
unfold fwdFDeriv; fun_trans + @[fun_trans] theorem HDiv.hDiv.arg_a0a1.fwdFDeriv_rule_at (x : X) (f : X → K) (g : X → K) @@ -262,18 +308,12 @@ theorem HDiv.hDiv.arg_a0a1.fwdFDeriv_rule_at (x : X) -------------------------------------------------------------------------------- @[fun_trans] -def HPow.hPow.arg_a0.fwdFDeriv_rule_at (n : Nat) (x : X) - (f : X → K) (hf : DifferentiableAt K f x) : - fwdFDeriv K (fun x => f x ^ n) x +def HPow.hPow.arg_a0.fwdFDeriv_rule (n : Nat) : + fwdFDeriv K (fun x : K => x ^ n) = - fun dx => - let ydy := fwdFDeriv K f x dx - (ydy.1 ^ n, n * ydy.2 * (ydy.1 ^ (n-1))) := by - unfold fwdFDeriv; - funext dx; simp - induction n - case zero => simp - case h _ => simp[pow_succ]; fun_trans; sorry_proof + fun x dx : K => + (x ^ n, n * dx * (x ^ (n-1))) := by + unfold fwdFDeriv; fun_trans -- IndexType.sum ---------------------------------------------------------------- diff --git a/SciLean/Analysis/Calculus/Notation/Gradient.lean b/SciLean/Analysis/Calculus/Notation/Gradient.lean index 151e0ea9..099baa46 100644 --- a/SciLean/Analysis/Calculus/Notation/Gradient.lean +++ b/SciLean/Analysis/Calculus/Notation/Gradient.lean @@ -22,14 +22,15 @@ elab_rules (kind:=gradNotation1) : term let XY ← mkArrow X Y -- Y might also be infered by the function `f` let fExpr ← withoutPostponing <| elabTermEnsuringType f XY false - let sX ← exprToSyntax X let .some (_,Y) := (← inferType fExpr).arrow? 
| return ← throwUnsupportedSyntax + let sX ← exprToSyntax X + let sK ← exprToSyntax K + let sY ← exprToSyntax Y if (← isDefEq K Y) then - elabTerm (← `(fgradient (X:=$sX) $f $x $xs*)) none false + elabTerm (← `(fgradient (X:=$sX) (K:=$sK) $f $x $xs*)) none false else - elabTerm (← `(adjointFDeriv (X:=$sX) defaultScalar% $f $x $xs*)) none false - + elabTerm (← `(adjointFDeriv (X:=$sX) (Y:=$sY) defaultScalar% $f $x $xs*)) none false | `(∇ $f) => do let K ← elabTerm (← `(defaultScalar%)) none diff --git a/SciLean/Analysis/Calculus/RevFDeriv.lean b/SciLean/Analysis/Calculus/RevFDeriv.lean index 3b8ad46b..df9a48eb 100644 --- a/SciLean/Analysis/Calculus/RevFDeriv.lean +++ b/SciLean/Analysis/Calculus/RevFDeriv.lean @@ -188,6 +188,17 @@ variable {E : ι → Type _} [∀ i, NormedAddCommGroup (E i)] [∀ i, AdjointSpace K (E i)] [∀ i, CompleteSpace (E i)] +-- of linear function ---------------------------------------------------------- +-------------------------------------------------------------------------------- + +@[fun_trans] +theorem revFDeriv_linear + (f : X → Y) (hf : IsContinuousLinearMap K f) : + revFDeriv K f + = + fun x => (f x, adjoint K f) := by unfold revFDeriv; fun_trans + + -- Prod.mk ----------------------------------- --------------------------------- -------------------------------------------------------------------------------- diff --git a/SciLean/Analysis/Calculus/RevFDerivProj.lean b/SciLean/Analysis/Calculus/RevFDerivProj.lean index 35ae9862..7f3b6cb0 100644 --- a/SciLean/Analysis/Calculus/RevFDerivProj.lean +++ b/SciLean/Analysis/Calculus/RevFDerivProj.lean @@ -13,18 +13,18 @@ namespace SciLean set_option deprecated.oldSectionVars true variable - (K I : Type _) [RCLike K] - {X : Type _} [NormedAddCommGroup X] [AdjointSpace K X] - {Y : Type _} [NormedAddCommGroup Y] [AdjointSpace K Y] - {Z : Type _} [NormedAddCommGroup Z] [AdjointSpace K Z] - {W : Type _} [NormedAddCommGroup W] [AdjointSpace K W] - {ι : Type _} [IndexType ι] [DecidableEq ι] - {κ : Type 
_} [IndexType κ] [DecidableEq κ] - {E : Type _} {EI : I → Type _} + (K I : Type) [RCLike K] + {X : Type} [NormedAddCommGroup X] [AdjointSpace K X] + {Y : Type} [NormedAddCommGroup Y] [AdjointSpace K Y] + {Z : Type} [NormedAddCommGroup Z] [AdjointSpace K Z] + {W : Type} [NormedAddCommGroup W] [AdjointSpace K W] + {ι : Type} [IndexType ι] [DecidableEq ι] + {κ : Type} [IndexType κ] [DecidableEq κ] + {E : Type} {EI : I → Type} [StructType E I EI] [IndexType I] [DecidableEq I] [NormedAddCommGroup E] [AdjointSpace K E] [∀ i, NormedAddCommGroup (EI i)] [∀ i, AdjointSpace K (EI i)] [VecStruct K E I EI] -- todo: define AdjointSpaceStruct - {F J : Type _} {FJ : J → Type _} + {F J : Type} {FJ : J → Type} [StructType F J FJ] [IndexType J] [DecidableEq J] [NormedAddCommGroup F] [AdjointSpace K F] [∀ j, NormedAddCommGroup (FJ j)] [∀ j, AdjointSpace K (FJ j)] [VecStruct K F J FJ] -- todo: define AdjointSpaceStruct @@ -329,6 +329,7 @@ set_option deprecated.oldSectionVars true variable {K : Type} [RCLike K] + {ι : Type} [IndexType ι] [DecidableEq ι] {X : Type} [NormedAddCommGroup X] [AdjointSpace K X] {Y : Type} [NormedAddCommGroup Y] [AdjointSpace K Y] {Z : Type} [NormedAddCommGroup Z] [AdjointSpace K Z] @@ -339,7 +340,7 @@ variable [NormedAddCommGroup Y'] [AdjointSpace K Y'] [∀ i, NormedAddCommGroup (YI i)] [∀ i, AdjointSpace K (YI i)] [VecStruct K Y' Yi YI] [NormedAddCommGroup Z'] [AdjointSpace K Z'] [∀ i, NormedAddCommGroup (ZI i)] [∀ i, AdjointSpace K (ZI i)] [VecStruct K Z' Zi ZI] {W : Type} [NormedAddCommGroup W] [AdjointSpace K W] - {ι : Type} [IndexType ι] + @@ -760,18 +761,13 @@ def HPow.hPow.arg_a0.revFDerivProjUpdate_rule section IndexTypeSum -variable {ι : Type} [IndexType ι] - @[fun_trans] -theorem IndexType.sum.arg_f.revFDerivProj_rule [DecidableEq ι] +theorem IndexType.sum.arg_f.revFDerivProj_rule (f : X → ι → Y') (hf : ∀ i, Differentiable K (fun x => f x i)) : revFDerivProj K Yi (fun x => ∑ i, f x i) = fun x => - -- this is not optimal - -- we should have but right 
now there is no appropriate StrucLike instance - -- let ydf := revFDerivProj K Yi f x let ydf := revFDerivProjUpdate K (ι×Yi) f x (∑ i, ydf.1 i, fun j dy => @@ -855,10 +851,10 @@ theorem dite.arg_te.revFDerivProjUpdate_rule section InnerProductSpace variable - {R : Type _} [RealScalar R] + {R : Type} [RealScalar R] -- {K : Type _} [Scalar R K] - {X : Type _} [NormedAddCommGroup X] [AdjointSpace R X] - {Y : Type _} [NormedAddCommGroup Y] [AdjointSpace R Y] + {X : Type} [NormedAddCommGroup X] [AdjointSpace R X] + {Y : Type} [NormedAddCommGroup Y] [AdjointSpace R Y] -- Inner ----------------------------------------------------------------------- -------------------------------------------------------------------------------- diff --git a/SciLean/Data/ArrayType/Properties.lean b/SciLean/Data/ArrayType/Properties.lean index 93b8b613..07c9f080 100644 --- a/SciLean/Data/ArrayType/Properties.lean +++ b/SciLean/Data/ArrayType/Properties.lean @@ -3,6 +3,7 @@ import SciLean.Data.ArrayType.Algebra import SciLean.Analysis.Convenient.HasAdjDiff import SciLean.Analysis.AdjointSpace.Adjoint +import SciLean.Analysis.Calculus.RevFDerivProj import SciLean.Meta.GenerateAddGroupHomSimp @@ -250,6 +251,50 @@ theorem ArrayType.ofFn.arg_f.adjoint_rule : end OnAdjointSpace +section OnAdjointSpace + +variable + [NormedAddCommGroup Elem] [AdjointSpace K Elem] [CompleteSpace Elem] + {I : Type} [IndexType I] [DecidableEq I] + {E : I → Type} [∀ i, NormedAddCommGroup (E i)] [∀ i, AdjointSpace K (E i)] + [∀ i, CompleteSpace (E i)] [StructType Elem I E] [VecStruct K Elem I E] + {W : Type} [NormedAddCommGroup W] [AdjointSpace K W] [CompleteSpace W] + + +@[fun_trans] +theorem ArrayType.get.arg_cont.revFDerivProj_rule (i : Idx) + (cont : W → Cont) (hf : Differentiable K cont) : + revFDerivProj K I (fun w => ArrayType.get (cont w) i) + = + fun w : W => + let xi := revFDerivProj K (Idx×I) cont w + (ArrayType.get xi.1 i, fun (j : I) (de : E j) => + xi.2 (i,j) de) := by + unfold revFDerivProj; 
fun_trans[oneHot] + funext x + fun_trans + funext i de + congr + funext i + split_ifs + · congr; funext i; aesop + · aesop + + +@[fun_trans] +theorem ArrayType.get.arg_cont.revFDerivProjUpdate_rule (i : Idx) + (cont : W → Cont) (hf : Differentiable K cont) : + revFDerivProjUpdate K I (fun w => ArrayType.get (cont w) i) + = + fun w : W => + let xi := revFDerivProjUpdate K (Idx×I) cont w + (ArrayType.get xi.1 i, fun (j : I) (de : E j) dw => + xi.2 (i,j) de dw) := by unfold revFDerivProjUpdate; fun_trans + + +end OnAdjointSpace + + #exit @[fun_trans] diff --git a/SciLean/Data/StructType/Algebra.lean b/SciLean/Data/StructType/Algebra.lean index d70c33f0..2ffc7d0c 100644 --- a/SciLean/Data/StructType/Algebra.lean +++ b/SciLean/Data/StructType/Algebra.lean @@ -107,9 +107,6 @@ instance [∀ i, MetricSpace (EI i)] [∀ j, MetricSpace (FJ j)] (i : I ⊕ J) : dist_self := sorry_proof dist_comm := sorry_proof dist_triangle := sorry_proof - edist := match i with - | .inl _ => PseudoMetricSpace.edist - | .inr _ => PseudoMetricSpace.edist edist_dist := sorry_proof toUniformSpace := by infer_instance uniformity_dist := sorry_proof diff --git a/SciLean/Tactic/FunTrans/Core.lean b/SciLean/Tactic/FunTrans/Core.lean index e5b4bd8f..e7f60b88 100644 --- a/SciLean/Tactic/FunTrans/Core.lean +++ b/SciLean/Tactic/FunTrans/Core.lean @@ -234,7 +234,7 @@ def applyApplyRule (funTransDecl : FunTransDecl) (e : Expr) : SimpM (Option Simp return none -def applyPiRule (funTransDecl : FunTransDecl) (e : Expr) : SimpM (Option Simp.Result) := do +def applyPiRule (funTransDecl : FunTransDecl) (e f : Expr) : SimpM (Option Simp.Result) := do let thms ← getLambdaTheorems funTransDecl.funTransName .pi e.getAppNumArgs if thms.size = 0 then @@ -242,7 +242,8 @@ def applyPiRule (funTransDecl : FunTransDecl) (e : Expr) : SimpM (Option Simp.Re return none for thm in thms do - if let .some r ← tryTheorem? 
e (.decl thm.thmName) then + let .pi id_f := thm.thmArgs | continue + if let .some r ← tryTheoremWithHint e (.decl thm.thmName) #[(id_f, f)] then return r return none @@ -253,7 +254,7 @@ def applyMorTheorems (funTransDecl : FunTransDecl) (e : Expr) (fData : FunProp.F match ← fData.isMorApplication with | .none => return none | .underApplied => - applyPiRule funTransDecl e + applyPiRule funTransDecl e (← fData.toExpr) | .overApplied => let .some (f,g) ← fData.peeloffArgDecomposition | return none applyCompRule funTransDecl e f g @@ -359,7 +360,7 @@ def tryTheorems (funTransDecl : FunTransDecl) (e : Expr) (fData : FunProp.Functi return none | .gt => trace[Meta.Tactic.fun_trans] s!"adding argument to later use {← ppOrigin' thm.thmOrigin}" - if let .some r ← applyPiRule funTransDecl e then + if let .some r ← applyPiRule funTransDecl e (← fData.toExpr) then return r continue | .eq => @@ -574,7 +575,7 @@ partial def funTrans (e : Expr) : SimpM Simp.Step := do | .lam f => trace[Meta.Tactic.fun_trans.step] "lam case on {← ppExpr f}" let e := e.setArg funTransDecl.funArgId f -- update e with reduced f - toStep <| applyPiRule funTransDecl e + toStep <| applyPiRule funTransDecl e f | .data fData => let e := e.setArg funTransDecl.funArgId (← fData.toExpr) -- update e with reduced f diff --git a/SciLean/Tactic/FunTrans/Theorems.lean b/SciLean/Tactic/FunTrans/Theorems.lean index b5b95a1c..82f9ce52 100644 --- a/SciLean/Tactic/FunTrans/Theorems.lean +++ b/SciLean/Tactic/FunTrans/Theorems.lean @@ -39,7 +39,7 @@ inductive LambdaTheoremArgs | letE (fArgId gArgId : Nat) /-- Pi theorem e.g. `fderiv ℝ fun x y => f x y = ...` -/ - | pi + | pi (fArgId : Nat) deriving Inhabited, BEq, Repr, Hashable /-- Tag for one of the 5 basic lambda theorems -/ @@ -66,7 +66,7 @@ def LambdaTheoremArgs.type (t : LambdaTheoremArgs) : LambdaTheoremType := | .comp .. => .comp | .letE .. => .letE | .apply => .apply - | .pi => .pi + | .pi .. 
=> .pi /-- Decides whether `f` is a function corresponding to one of the lambda theorems. -/ def detectLambdaTheoremArgs (f : Expr) (ctxVars : Array Expr) : @@ -91,8 +91,9 @@ def detectLambdaTheoremArgs (f : Expr) (ctxVars : Array Expr) : let .some argId_f := ctxVars.findIdx? (fun x => x == (.fvar fId)) | return none let .some argId_g := ctxVars.findIdx? (fun x => x == (.fvar gId)) | return none return .some <| .letE argId_f argId_g - | .lam _ _ (.app (.app (.fvar _) (.bvar 1)) (.bvar 0)) _ => - return .some .pi + | .lam _ _ (.app (.app (.fvar fId) (.bvar 1)) (.bvar 0)) _ => + let .some argId_f := ctxVars.findIdx? (fun x => x == (.fvar fId)) | return none + return .some <| .pi argId_f | _ => return none | _ => return none diff --git a/doc/talk/august_umbc_lecture.lean b/doc/talk/august_umbc_lecture.lean new file mode 100644 index 00000000..e0e8ac03 --- /dev/null +++ b/doc/talk/august_umbc_lecture.lean @@ -0,0 +1,556 @@ +import SciLean + +open SciLean Scalar RealScalar + + + +@[simp ↓ high, simp_core ↓ high] +theorem hihih + {R : Type} [RCLike R] + {X : Type} [NormedAddCommGroup X] [AdjointSpace R X] [CompleteSpace X] [PlainDataType X] + {Y : Type} [NormedAddCommGroup Y] [AdjointSpace R Y] [CompleteSpace Y] + {I : Type} [IndexType I] + (f : DataArrayN X I → Y) : + revFDeriv R f + = + fun x => + let ydf := revFDerivProj R Unit f x + (ydf.1, ydf.2 ()) := by unfold revFDerivProj revFDeriv; simp + + +@[fun_trans] +theorem hohoe {n:Nat} (i : Fin n) : + revFDerivProjUpdate Float Unit (fun (x : Float^[n]) => x[i]) + = + fun x => + (ArrayType.get x i, + fun _ dxi y => ArrayType.modify y i (· + dxi)) := sorry + +#check fun (i : Fin 10 × Unit) => + (?inst : (i : Unit) → NormedAddCommGroup ((?EI : Unit → Type) i)) + i.2 + + + +set_default_scalar Float +variable (u : Float^[10]) +-- set_option trace.Meta.Tactic.fun_trans true in +set_option trace.Meta.Tactic.fun_trans.unify true in +-- set_option trace.Meta.Tactic.simp.rewrite true in +-- set_option pp.funBinderTypes true in 
+-- set_option pp.mvars.withType true in + +set_option trace.Meta.isDefEq true in + +#check (∇ (u':=u), ∑ i, u'[i]^2) rewrite_by + unfold fgradient + autodiff + simp + pattern (revFDerivProjUpdate _ _ _) + rw[SciLean.revFDerivProjUpdate.pi_rule (K:=Float) (I:=Unit) (f:=fun (x : Float^[10]) i => x[i]^2) (hf:=by fun_prop)] + + + + +-- =?= fun (i : Fin 10 × Unit) => NonUnitalNormedRing.toNormedAddCommGroup + +#exit +-- __ __ _ _ _ _ _ +-- \ \ / /__ _ _| |_(_)_ _ __ _ __ __ _(_) |_| |_ +-- \ \/\/ / _ \ '_| / / | ' \/ _` | \ V V / | _| ' \ +-- \_/\_/\___/_| |_\_\_|_||_\__, | \_/\_/|_|\__|_||_| +-- |___/ +-- _ +-- /_\ _ _ _ _ __ _ _ _ ___ +-- / _ \| '_| '_/ _` | || (_-< +-- /_/ \_\_| |_| \__,_|\_, /__/ +-- |__/ +section WorkingWithArrays + + +-- _ _ _ _ _ +-- | | (_) |_ ___ _ _ __ _| | /_\ _ _ _ _ __ _ _ _ ___ +-- | |__| | _/ -_) '_/ _` | | / _ \| '_| '_/ _` | || (_-< +-- |____|_|\__\___|_| \__,_|_| /_/ \_\_| |_| \__,_|\_, /__/ +-- |__/ + +-- List +#check [1.0, 2.0] +-- Array +#check #[1.0, 2.0] + +-- Mathlib's vector +#check ![1.0, 2.0] +-- Mathlib's matrix +#check !![1.0, 2.0; 3.0, 4.0] + +-- SciLean's vector +#check ⊞[1.0, 2.0] +#eval ⊞[1.0, 2.0] + +-- SciLean's matrix +#check ⊞[1.0, 2.0; 3.0, 4.0] +#eval ⊞[1.0, 2.0; 3.0, 4.0] + + +-- ___ _ _ _ +-- | __| |___ _ __ ___ _ _| |_ /_\ __ __ ___ ______ +-- | _|| / -_) ' \/ -_) ' \ _| / _ \/ _/ _/ -_|_-<_-< +-- |___|_\___|_|_|_\___|_||_\__| /_/ \_\__\__\___/__/__/ + +def u := ⊞[1.0, 2.0] +def A := ⊞[1.0, 2.0; 3.0, 4.0] + +-- element access +#eval u[1] +#eval A[0,1] + +-- automatic index type inference +#check fun i => u[i] +#check fun i j => A[i,j] +#check fun ij => A[ij] +#check fun (i,j) => A[i,j] + +#eval ∑ i, u[i] +#eval ∑ i j, A[i,j] +#eval ∏ ij, A[ij] + +-- lambda notation for arrays +#check ⊞ (i : Fin 4) => i.1.toFloat +#eval ⊞ (i : Fin 4) => i.1.toFloat + +-- lambda notation for matrices +#check ⊞ (i j : Fin 4) => i.1.toFloat + 4 * j.1.toFloat +#eval ⊞ (i j : Fin 4) => i.1.toFloat + 4 * j.1.toFloat + +-- beware 
nested lambda notation creates vector of vectors +#check ⊞ (i : Fin 4) => ⊞ (j : Fin 4) => i.1.toFloat + 4 * j.1.toFloat + + +-- imperative code to create vector +def array1 := Id.run do + let mut x : Float^[4] := 0 + for i in fullRange (Fin 4) do + x[i] := i.1.toFloat + return x + +#check array1 +#eval array1 + +-- using standard range notation is cumbersome right now +def array2 := Id.run do + let mut x : Float^[4] := 0 + for h : i in [0:4] do + let i : Fin 4 := ⟨i, by simp_all [Membership.mem]⟩ + x[i] := i.1.toFloat + return x + +-- imperative code to create matrix +def matrix1 := Id.run do + let mut A : Float^[4,4] := 0 + for (i,j) in fullRange (Fin 4 × Fin 4) do + A[i,j] := i.1.toFloat + 4 * j.1.toFloat + return A + +#check matrix1 +#eval matrix1 + + +-- dot product for vectors +def dot {n : Nat} (x y : Float^[n]) : Float := ∑ i, x[i] * y[i] +def matMul {n m : Nat} (A : Float^[n,m]) (x : Float^[m]) : Float^[n] := ⊞ i => ∑ j, A[i,j] * x[j] + +#eval dot u u +#eval matMul A u + +-- dimension mismatch +#check_failure dot u A +#check_failure matMul A A + +-- not general enough in the index type +#check_failure dot A A + + +-- ___ _ ___ _ +-- / __|___ _ _ ___ _ _ __ _| | |_ _|_ _ __| |_____ __ +-- | (_ / -_) ' \/ -_) '_/ _` | | | || ' \/ _` / -_) \ / +-- \___\___|_||_\___|_| \__,_|_| |___|_||_\__,_\___/_\_\ + +-- explain syntactic sugar Float^[n,m] +example : Float^[2] = DataArrayN Float (Fin 2) := by rfl +example : Float^[2] = Float^[Fin 2] := by rfl + +#check Float^[Fin 2 × Fin 2] +#check Float^[Fin 2 ⊕ Fin 3] +#check Float^[(Fin 2 ⊕ Fin 3) × Set.Icc (-2:ℤ) (2:ℤ)] + +example : IndexType (Fin 2 × Fin 2) := by infer_instance +example : IndexType (Fin 2 ⊕ Fin 3 × Set.Icc (-2:ℤ) (2:ℤ)) := by infer_instance + +-- 2 * 5^0 + 3 * 5^1 + 1 * 5^2 +#eval IndexType.toFin ((2, 3, 1) : Fin 5 × Fin 5 × Fin 5) + +-- example of using more complicated indices +#eval ⊞ (i : Set.Icc (-2:ℤ) (2:ℤ)) => Float.ofInt i.1 + +#eval ⊞ (i : Fin 2 ⊕ Fin 3) => + match i with + | .inl i => 
i.1.toFloat + | .inr j => 100 * j.1.toFloat + +#check ⊞ (i : (Fin 2 ⊕ Fin 3) × Set.Icc (-2:ℤ) (2:ℤ) × Fin 2) => + (IndexType.toFin i).1.toFloat + + +-- generalized dot +variable {I J : Type} [IndexType I] [IndexType J] +def dot' (x y : Float^[I]) : Float := ∑ i, x[i] * y[i] + +#eval dot' u u +#eval dot' A A + +-- generalized matMul +def matMul' (A : Float^[I,J]) (x : Float^[J]) : Float^[I] := ⊞ i => ∑ j, A[i,j] * x[j] + +#eval matMul' A u + +-- initialize rank 3 tensor T +def T := ⊞ (i j k : Fin 2) => i.1.toFloat + 2 * j.1.toFloat + 4 * k.1.toFloat + +-- test dot on T +#eval dot' T T + +-- test matMul on T +-- it works because (T : Float^[Fin 2 × (Fin 2 × Fin 2)]) (A : Float^[Fin 2 × Fin 2]) +-- thus we have (I := Fin 2) (J := Fin 2 × Fin 2) +#eval matMul' T A + + +-- imperative style implementation of matMul +def matMul'' (A : Float^[I,J]) (x : Float^[J]) : Float^[I] := Id.run do + let mut y : Float^[I] := 0 + for i in fullRange I do + for j in fullRange J do + y[i] += A[i,j] * x[j] + return y + +#eval matMul'' T A + + +end WorkingWithArrays + +-- _ _ _ _ +-- /_\ _ _| |_ ___ _ __ __ _| |_(_)__ +-- / _ \ || | _/ _ \ ' \/ _` | _| / _| +-- /_/ \_\_,_|\__\___/_|_|_\__,_|\__|_\__| +-- ___ _ __ __ _ _ _ _ +-- | \(_)/ _|/ _|___ _ _ ___ _ _| |_(_)__ _| |_(_)___ _ _ +-- | |) | | _| _/ -_) '_/ -_) ' \ _| / _` | _| / _ \ ' \ +-- |___/|_|_| |_| \___|_| \___|_||_\__|_\__,_|\__|_\___/_||_| + +section AutomaticDifferentiation + + +-- __ _ _ +-- / _|__| |___ _ _(_)_ __ +-- | _/ _` / -_) '_| \ V / +-- |_| \__,_\___|_| |_|\_/ + +section FDeriv + +variable (x₀ : ℝ) + +-- mathlib's fderiv +#check (fderiv ℝ (fun x : ℝ => x*x*x) x₀ 1) rewrite_by autodiff + +set_default_scalar ℝ + +-- nice notation for derivative +#check (∂ x : ℝ, (x*x*x + x*x)) rewrite_by autodiff + +#check (∂ (x:=x₀), (x*x*x + x*x)) rewrite_by autodiff + +#check (∂ (x:=x₀), (sin x + exp x + cos x)) rewrite_by autodiff + + +-- differentiating more complicated expressions +#check (∂ (x:=x₀), ∑ (i : Fin 10), sin (x^i.1)) 
rewrite_by autodiff + +-- differentiating w.r.t. vector or matrix arguments +variable {n : Nat} (i : Fin n) +variable (A dA : Fin n → Fin n → ℝ) (u du v dv : Fin n → ℝ) + +#check (∂ (u':=u;du), ∑ j, A i j * u' j) rewrite_by autodiff +#check (∂ (A':=A;dA), ∑ j, A' i j * u j) rewrite_by autodiff +#check (∂ ((A',u'):=(A,u);(dA,du)), ∑ j, A' i j * u' j) rewrite_by autodiff + +-- right now this is really slow for some reason :( +-- #check (∂ (A':=A;dA), (fun i => ∑ j, A' i j * ∑ k, A' j k * u k)) rewrite_by autodiff + + +-- differentiating w.r.t. arrays +set_default_scalar Float +variable (A dA : Float^[n,n]) (u du v dv : Float^[n]) + +#check (∂ (u':=u;du), (⊞ i => ∑ j, A[i,j] * u'[j] )) rewrite_by autodiff +#check (∂ (A':=A;dA), (⊞ i => ∑ j, A'[i,j] * u[j] )) rewrite_by autodiff + +-- right now this is extremely slow for some reason :( +-- #check (∂ ((A',u'):=(A,u);(dA,du)), (⊞ i => ∑ j, A'[i,j] * u'[j] )) rewrite_by autodiff +-- #check (∂ (A':=A;dA), (⊞ i => ∑ j, A'[i,j] * ∑ k, A'[j,k]*u[k] )) rewrite_by autodiff + +set_default_scalar ℝ + +-- differentiating division, log +set_option trace.Meta.Tactic.fun_trans true in +#check (∂ (x:=x₀), (exp x / x)) rewrite_by autodiff + +variable (h : x₀ ≠ 0) +#check (∂ (x:=x₀), (exp x / x)) rewrite_by assuming autodiff + +#check (∂ (x:=x₀), (exp x / (x*x))) rewrite_by assuming (h : x₀ ≠ 0) autodiff (disch:=assumption) + +#check (∂ (x:=x₀), exp x / (x^2 - 3*x + 1)) rewrite_by assuming (h : ∀ x : ℝ, x≠0) autodiff (disch:=apply h) + +#check (∂ (x:=x₀), exp x / (x^2 - 3*x + 1)) rewrite_by assuming (h : x₀ ^ 2 - 3 * x₀ + 1 ≠ 0) autodiff (disch:=apply h) + +#check (∂ (x:=x₀), log (x^n)) rewrite_by autodiff (disch:=aesop) + + +end FDeriv +-- _ _ _ +-- __ _ _ _ __ _ __| (_)___ _ _| |_ +-- / _` | '_/ _` / _` | / -_) ' \ _| +-- \__, |_| \__,_\__,_|_\___|_||_\__| +-- |___/ + +section Gradient +set_default_scalar ℝ +variable (x : ℝ) + +#check (∇ (x':=x), x'*x'*x') rewrite_by autodiff + +-- Taking gradient w.r.t. 
to matrix and vector +variable {n : Nat} (A dA : Fin n → Fin n → ℝ) (u du v dv : Fin n → ℝ) (i : Fin n) +-- ∇_u ‖u‖₂² = 2•u +#check (∇ (u':=u), ‖u'‖₂²) rewrite_by autodiff +-- ∇_u ⟪u,v⟫ = v +#check (∇ (u':=u), ⟪u',v⟫) rewrite_by autodiff + + +-- ∇_u u[i] = [0,...,1,...,0] +#check (∇ (u':=u), u' i) rewrite_by autodiff + +-- accessing elements does not work properly yet +-- ∇_u (∑ i, uᵢ²) = ∇_u ‖u‖₂² = 2 • u +#check (∇ (u':=u), ∑ i, u' i^2) rewrite_by autodiff + + +set_default_scalar Float +variable (A dA : Float^[3,3]) (u du : Float^[3]) + +-- ∇_u u[i] = [0,...,1,...,0] +#check (∇ (u':=u), u'[1]) rewrite_by autodiff +#eval (∇ (u':=u), u'[1]) rewrite_by autodiff + + +-- ∇_A trace(A) = I +#check (∇ (A':=A), ∑ i, A'[i,i]) rewrite_by autodiff +#eval (∇ (A':=A), ∑ i, A'[i,i]) rewrite_by autodiff + + + +end Gradient + +-- ___ _ _ +-- | __|__ _ ___ __ ____ _ _ _ __| | __ _ _ _ __| | +-- | _/ _ \ '_\ V V / _` | '_/ _` | / _` | ' \/ _` | +-- |_|\___/_| \_/\_/\__,_|_| \__,_| \__,_|_||_\__,_| +-- ___ __ __ _ +-- | _ \_____ _____ _ _ ___ ___ | \/ |___ __| |___ +-- | / -_) V / -_) '_(_- (x':=x;dx), f x' := by rfl +example : revFDeriv ℝ f x = <∂ (x':=x), f x' := by rfl + +variable (g : X → ℝ) +-- gradient is just revFDeriv +example : ∇ (x':=x), g x' = (<∂ (x':=x), g x').2 1 := by rfl + + +#check (∂> (x : ℝ), x^3) rewrite_by autodiff +#check (<∂ (x : Fin 10 → ℝ), ‖x‖₂²) rewrite_by autodiff + + +#check (∂> (x : ℝ), + let t1 := x^2 + let t2 := t1^3 + let t3 := t2^4 + let t4 := t3^5 + t4) rewrite_by autodiff + +#check (<∂ (x : ℝ), + let t1 := x^2 + let t2 := t1^3 + let t3 := t2^4 + let t4 := t3^5 + t4) rewrite_by autodiff + + +set_default_scalar Float + +#check (∂> (x : Float^[10]), + let mean := ∑ i, x[i] + let var := (1.0/9.0) * ∑ i, (x[i] - mean)^2 + var) rewrite_by autodiff + +#check (<∂ (x : Float^[10]), + let mean := ∑ i, x[i] + let var := (1.0/9.0) * ∑ i, (x[i] - mean)^2 + var) rewrite_by autodiff + +set_default_scalar ℝ + + +-- _ _ ___ __ _ _ +-- | | | |___ ___ _ _ | \ ___ / 
_(_)_ _ ___ __| | +-- | |_| (_- foo by unfold foo; autodiff +def_fun_trans : <∂ foo by unfold foo; autodiff + +-- `∇ foo`, `deriv foo` and `∂ foo` do not work as they are not marked as function transformations +-- def_fun_trans : ∇ foo by unfold foo; autodiff +-- def_fun_trans : deriv foo by unfold foo; autodiff +-- def_fun_trans : ∂ foo by unfold foo; autodiff + +-- check that AD works now +#check (∂ x, foo x) rewrite_by autodiff + +-- however nesting does not work, why? +set_option trace.Meta.Tactic.fun_trans true in +#check (∂ x, foo (foo x)) rewrite_by autodiff + +-- define new function proposition +def_fun_prop with_transitive : Differentiable ℝ foo by unfold foo; fun_prop + +-- check new theorem and demonstrate `with_transitive` keyword +#check foo.arg_x.Differentiable_rule + + + +-- ___ _ _ +-- | __| _ _ _ __| |_(_)___ _ _ +-- | _| || | ' \/ _| _| / _ \ ' \ +-- |_| \_,_|_||_\__|\__|_\___/_||_| +-- _____ __ _ _ +-- |_ _| _ __ _ _ _ ___/ _|___ _ _ _ __ __ _| |_(_)___ _ _ +-- | || '_/ _` | ' \(_-< _/ _ \ '_| ' \/ _` | _| / _ \ ' \ +-- |_||_| \__,_|_||_/__/_| \___/_| |_|_|_\__,_|\__|_\___/_||_| + + +-- SciLean provides general tactic `fun_trans` for function transformation +-- inspired by JAX + +-- `autodiff` is just `fun_trans` with custom settings + +-- define custom derivative +@[fun_trans] +noncomputable +def myderiv (f : ℝ → ℝ) (x : ℝ) := fderiv ℝ f x 1 + + +-- basic lambda calculus rules + +-- identity rule: d/dx x = 1 +@[fun_trans] +theorem id_rule : myderiv (fun x : ℝ => x) = fun x => 1 := by unfold myderiv; fun_trans + +-- constant rule: d/dx constant = 0 +@[fun_trans] +theorem const_rule (y : ℝ) : myderiv (fun x : ℝ => y) = fun x => 0 := by unfold myderiv; fun_trans + +-- chain/composition rule: (f(g(x)))' = f'(g(x))*g'(x) +@[fun_trans] +theorem comp_rule (f g : ℝ → ℝ) (hf : Differentiable ℝ f) (hg : Differentiable ℝ g) : + myderiv (fun x => f (g x)) + = + fun x => myderiv f (g x) * myderiv g x := by unfold myderiv; fun_trans[mul_comm] + + +-- 
derivative rules for operations + +-- addition rule: (f + g)' = f' + g' +@[fun_trans] +theorem add_rule (f g : ℝ → ℝ) (hf : Differentiable ℝ f) (hg : Differentiable ℝ g) : + myderiv (fun x => f x + g x) + = + fun x => myderiv f x + myderiv g x := by unfold myderiv; fun_trans + + +-- multiplication/Leibnitz rule: (f*g)' = f'*g + f*g' +@[fun_trans] +theorem mul_rule (f g : ℝ → ℝ) (hf : Differentiable ℝ f) (hg : Differentiable ℝ g) : + myderiv (fun x => f x * g x) + = + fun x => myderiv f x * g x + f x * myderiv g x := by unfold myderiv; fun_trans[mul_comm,add_comm] + + +-- test `myderiv` with `fun_trans` +#check (myderiv (fun x : ℝ => x*x*x*x + x*x)) rewrite_by fun_trans + + + + + + + -- conv => + -- pattern revFDerivProjUpdate _ _ _ + -- rw [revFDerivProjUpdate.pi_rule (K:=Float) (ι:=Fin 10) (I:=Unit) (f:=fun (x :Float^[10]) i => x[i]^2) (hf:=by fun_prop)] + -- autodiff + + -- autodiff + -- simp + -- enter [1,dx,i] + -- + -- autodiff + + +end AutomaticDifferentiation diff --git a/doc/talk/august_umbc_lecture.org b/doc/talk/august_umbc_lecture.org new file mode 100644 index 00000000..ac2de446 --- /dev/null +++ b/doc/talk/august_umbc_lecture.org @@ -0,0 +1,100 @@ + +* SciLean + +** library for scientific computing + +** motivation - mix of Lean, Mathematica, Julia, JAX + +** priorities: usability, performance, ...., formal correctness + - SciLean is not formalization project (at least not primarily) + +* Harmonic oscillator example + + +* Talk overview + +** Working with arrays + +** symbolic and automatic differentiation + +*** fderiv and how to use autodiff + + ∂ x', f x' = fderiv R f + + + + + ∂ (x':=x), f x' = fderiv R f x + + ∂ (x':=x;dx), f x' = fderiv R f x dx + + For scalar arguments the notation automatically inserts `dx = 1` + + + +*** gradient, start using notation + + - fgradient vs mathlib's gradient - RxR and Fin n -> R is not InnerProductSpace + +*** forward and reverse mode AD + + ∂> x':=x;dx, f x dx = fwdFDeriv R f x dx = (f x, fderiv R f x dx) + + + 
revFDeriv R f x = (f x, adjoint R (fderiv R f x)) + + +*** Working with user-defined functions + + def foo (x : R) := 3*x^3 + x^2 + + def_fun_trans : ∂ x, foo x by unfold foo; autodiff + def_fun_trans : ∂> x, foo x by unfold foo; autodiff + def_fun_trans : <∂ x, foo x by unfold foo; autodiff + + #print foo.arg_x.fderiv + #check foo.arg_x.fderiv_rule + + #check (∂ x, foo x) rewrite_by autodiff + +*** general function transformation + - tactics: + fun_prop - proving function properties like Continuous, Differentiable + - part of mathlib + fun_trans - function transformation tactic to compute derivatives, adjoint etc. + - part of SciLean + + - have a look at documentation of fun_prop + + @[fun_trans] + def myderiv (f : R -> R) (x : R) : R := fderiv R f x 1 + + @[fun_trans] + theorem id_rule : myderiv (fun x : R => x) = fun x => 1 := sorry + @[fun_trans] + theorem const_rule (y : R) : myderiv (fun x : R => y) = fun x => 0 := sorry + -- (f(g(x)))' = f'(g(x))*g'(x) + @[fun_trans] + theorem comp_rule (f g : R -> R) (hf : Differentiable R f) (hg : Differentiable R g) : + myderiv (fun x => f (g x)) + = + fun x => + let y := g x + let dy := myderiv g x + myderiv f (g x) * myderiv g x := sorry + + + variable (f g : R -> R) (hf : Differentiable R f) (hg : Differentiable R g) + + -- (f + g)' = f' + g' + @[fun_trans] + theorem add_rule : myderiv (fun x => f x + g x) = fun x => myderiv f x + myderiv g x := sorry + + -- (f * g)' = f'*g + f*g' + @[fun_trans] + theorem mul_rule : myderiv (fun x => f x * g x) = fun x => myderiv f x * g x + f x * myderiv g x := sorry + + + #check (myderiv (fun x : R => x*x*x + x*x)) rewrite_by fun_trans + +** probabilistic programming