From 4af1c952e466e513f8edfc0378d7378bcdd074d6 Mon Sep 17 00:00:00 2001
From: lecopivo
Date: Sat, 30 Mar 2024 19:32:54 +0100
Subject: [PATCH] some refactoring of distributions

Summary of the changes carried by this patch:

* Basic.lean: drop `vecDirac`; `dirac` becomes a plain definition
  `⟨fun φ ⊸ φ x⟩`; `action_vecDirac` is replaced by `action_dirac`;
  add `postComp_id`, `postComp_comp`, `postComp_assoc`,
  `postComp_extAction` and `postComp_restrict_extAction`
  (all still `sorry_proof`).
* Eval.lean: restate the `extAction` lemmas for `dirac`; replace
  `postExtAction_vecDirac` by `postExtAction_postComp`.
* ParametricDistribDeriv.lean: `vecDiracDeriv` becomes `diracDeriv`;
  `parDistribDeriv` no longer branches on `DistribDifferentiableAt`;
  add `action_parDistribDeriv` and rules for `restrict` and
  `toDistribution`.
* ParametricDistribRevDeriv.lean: `parDistribRevDeriv` returns `dz`
  instead of `sorry`; `comp_rule` now assumes `HasAdjDiff R g`; add
  `bind_rule'`, `diracRevDeriv` and the `dirac` reverse-mode rule.
* SimpleExamples.lean, SimpleExamples2D.lean: clean-up; `foo4` is
  disabled behind `#exit`.
* SurfaceDirac.lean: `surfaceDirac_dirac` is stated through
  `(dirac x).postComp`.

---
 SciLean/Core/Distribution/Basic.lean          | 34 ++++++-
 SciLean/Core/Distribution/Eval.lean           | 23 +++--
 .../Distribution/ParametricDistribDeriv.lean  | 92 +++++++++++++++----
 .../ParametricDistribRevDeriv.lean            | 67 +++++++++++++-
 SciLean/Core/Distribution/SimpleExamples.lean |  6 --
 .../Core/Distribution/SimpleExamples2D.lean   | 23 +++--
 SciLean/Core/Distribution/SurfaceDirac.lean   |  2 +-
 7 files changed, 192 insertions(+), 55 deletions(-)

diff --git a/SciLean/Core/Distribution/Basic.lean b/SciLean/Core/Distribution/Basic.lean
index 310495d7..8649d0df 100644
--- a/SciLean/Core/Distribution/Basic.lean
+++ b/SciLean/Core/Distribution/Basic.lean
@@ -151,8 +151,7 @@ simproc_decl Distribution.mk_extAction_simproc (Distribution.extAction (Distribu
 --   seqRight_eq := by intros; rfl
 --   pure_seq := by intros; rfl
 
-def vecDirac (x : X) (y : Y) : 𝒟'(X,Y) := ⟨fun φ ⊸ φ x • y⟩
-abbrev dirac (x : X) : 𝒟' X := vecDirac x 1
+def dirac (x : X) : 𝒟' X := ⟨fun φ ⊸ φ x⟩
 
 open Notation
 noncomputable
@@ -170,7 +169,7 @@ def Distribution.bind' (x' : 𝒟'(X,U)) (f : X → 𝒟'(Y,V)) (L : U → V →
 ----------------------------------------------------------------------------------------------------
 
 @[simp, ftrans_simp]
-theorem action_vecDirac (x : X) (y : Y) (φ : 𝒟 X) : ⟪(vecDirac x y), φ⟫ = φ x • y := by simp[dirac,vecDirac]
+theorem action_dirac (x : X) (φ : 𝒟 X) : ⟪dirac x, φ⟫ = φ x := by simp[dirac]
 
 @[simp, ftrans_simp]
 theorem action_bind (x : 𝒟'(X,Z)) (f : X → 𝒟' Y) (φ : 𝒟 Y) :
@@ -411,6 +410,35 @@ abbrev Distribution.postRestrict (T : 𝒟'(X,𝒟'(Y,Z))) (A : X → Set Y) :
    sorry_proof⟩⟩
 
 
+@[simp, ftrans_simp]
+theorem postComp_id (u : 𝒟'(X,Y)) :
+    (u.postComp (fun y => y)) = u := sorry_proof
+
+@[simp, ftrans_simp]
+theorem postComp_comp (x : 𝒟'(X,U)) (g : U → V) (f : V → W) :
+    (x.postComp g).postComp f
+    =
+    x.postComp (fun u => f (g u)) := sorry_proof
+
+@[simp, ftrans_simp]
+theorem postComp_assoc (x : 𝒟'(X,U)) (y : U → 𝒟'(Y,V)) (f : V → W) (φ : Y → R) :
+    (x.postComp y).postComp (fun T => T.postComp f)
+    =
+    (x.postComp (fun u => (y u).postComp f)) := sorry_proof
+
+@[action_push]
+theorem postComp_extAction (x : 𝒟'(X,U)) (y : U → V) (φ : X → R) :
+    (x.postComp y).extAction φ
+    =
+    y (x.extAction φ) := sorry_proof
+
+@[action_push]
+theorem postComp_restrict_extAction (x : 𝒟'(X,U)) (y : U → V) (A : Set X) (φ : X → R) :
+    ((x.postComp y).restrict A).extAction φ
+    =
+    y ((x.restrict A).extAction φ) := sorry_proof
+
+
 @[simp, ftrans_simp, action_push]
 theorem Distribution.zero_postExtAction (φ : Y → R) :
     (0 : 𝒟'(X,𝒟'(Y,Z))).postExtAction φ = 0 := by sorry_proof
diff --git a/SciLean/Core/Distribution/Eval.lean b/SciLean/Core/Distribution/Eval.lean
index abfc1e7c..6f8d6c48 100644
--- a/SciLean/Core/Distribution/Eval.lean
+++ b/SciLean/Core/Distribution/Eval.lean
@@ -11,6 +11,9 @@ variable
   {X} [TopologicalSpace X] [space : TCOr (Vec R X) (DiscreteTopology X)]
   {Y} [Vec R Y]
   {Z} [Vec R Z]
+  {U} [Vec R U]
+  {V} [Vec R V]
+  {W} [Vec R W]
 
 set_default_scalar R
 
@@ -21,22 +24,24 @@ theorem action_extAction (T : 𝒟' X) (φ : 𝒟 X) :
     T.action φ = T.extAction φ := sorry_proof
 
 @[action_push]
-theorem extAction_vecDirac (x : X) (y : Y) (φ : X → R) :
-    (vecDirac x y).extAction φ
+theorem extAction_vecDirac (x : X) (φ : X → R) :
+    (dirac x).extAction φ
     =
-    φ x • y := sorry_proof
+    φ x := sorry_proof
 
 @[action_push]
-theorem extAction_restrict_vecDirac (x : X) (y : Y) (A : Set X) (φ : X → R) :
-    ((vecDirac x y).restrict A).extAction φ
+theorem extAction_restrict_vecDirac (x : X) (A : Set X) (φ : X → R) :
+    ((dirac x).restrict A).extAction φ
     =
-    if x ∈ A then φ x • y else 0 := sorry_proof
+    if x ∈ A then φ x else 0 := sorry_proof
+
+    -- x.postComp (fun u => (y u).extAction φ) := by sorry_proof
 
 @[action_push]
-theorem postExtAction_vecDirac (x : X) (y : 𝒟'(Y,Z)) (φ : Y → R) :
-    (vecDirac x y).postExtAction φ
+theorem postExtAction_postComp (x : 𝒟'(X,U)) (y : U → 𝒟'(Y,Z)) (φ : Y → R) :
+    (x.postComp y).postExtAction φ
     =
-    vecDirac x (y.extAction φ) := sorry_proof
+    x.postComp (fun u => (y u).extAction φ) := by sorry_proof
 
 variable [MeasureSpace X]
 
diff --git a/SciLean/Core/Distribution/ParametricDistribDeriv.lean b/SciLean/Core/Distribution/ParametricDistribDeriv.lean
index 61277aee..11b7f062 100644
--- a/SciLean/Core/Distribution/ParametricDistribDeriv.lean
+++ b/SciLean/Core/Distribution/ParametricDistribDeriv.lean
@@ -13,16 +13,17 @@ open Distribution
 variable
   {R} [RealScalar R]
   {W} [Vec R W]
-  {X} [Vec R X]
+  {X} [Vec R X] [MeasureSpace X]
   {Y} [Vec R Y] [Module ℝ Y]
   {Z} [Vec R Z] [Module ℝ Z]
   {U} [Vec R U] -- [Module ℝ U]
 
+
 set_default_scalar R
 
 
 noncomputable
-def vecDiracDeriv (x dx : X) (y dy : Y) : 𝒟'(X,Y) := ⟨fun φ ⊸ φ x • dy + cderiv R φ x dx • y⟩
+def diracDeriv (x dx : X) : 𝒟' X := ⟨fun φ ⊸ cderiv R φ x dx⟩
 
 @[fun_prop]
 def DistribDifferentiableAt (f : X → 𝒟'(Y,Z)) (x : X) :=
@@ -43,15 +44,25 @@ def DistribDifferentiable (f : X → 𝒟'(Y,Z)) :=
   ∀ x, DistribDifferentiableAt f x
 
 
+-- TODO:
+--   probably change the definition of `parDistribDeriv` to:
+--     ⟨⟨fun φ =>
+--       if h : DistribDifferentiableAt f x then
+--         ∂ (x':=x;dx), ⟪f x', φ⟫
+--       else
+--         0 , sorry_proof⟩⟩
+--   I believe in that case the function is indeed linear in φ
+
 open Classical in
 @[fun_trans]
 noncomputable
 def parDistribDeriv (f : X → 𝒟'(Y,Z)) (x dx : X) : 𝒟'(Y,Z) :=
-  ⟨⟨fun φ =>
-    if _ : DistribDifferentiableAt f x then
-      ∂ (x':=x;dx), ⟪f x', φ⟫
-    else
-      0, sorry_proof⟩⟩
+  ⟨⟨fun φ => ∂ (x':=x;dx), ⟪f x', φ⟫, sorry_proof⟩⟩
+
+
+@[simp, ftrans_simp]
+theorem action_parDistribDeriv (f : X → 𝒟'(Y,Z)) (x dx : X) (φ : 𝒟 Y) :
+    ⟪parDistribDeriv f x dx, φ⟫ = ∂ (x':=x;dx), ⟪f x', φ⟫ := rfl
 
 
 ----------------------------------------------------------------------------------------------------
@@ -79,32 +90,28 @@ theorem parDistribDeriv.const_rule (T : 𝒟'(X,Y)) :
 ----------------------------------------------------------------------------------------------------
 
 @[fun_prop]
-theorem vecDirac.arg_xy.DistribDiffrentiable_rule
-    (x : W → X) (y : W → Y) (hx : CDifferentiable R x) (hy : CDifferentiable R y) :
-    DistribDifferentiable (R:=R) (fun w => vecDirac (x w) (y w)) := by
+theorem dirac.arg_xy.DistribDiffrentiable_rule
+    (x : W → X) (hx : CDifferentiable R x) :
+    DistribDifferentiable (R:=R) (fun w => dirac (x w)) := by
   intro x
   unfold DistribDifferentiableAt
   intro φ hφ
-  simp [action_vecDirac, dirac]
+  simp [action_dirac, dirac]
   fun_prop
 
 
 @[fun_trans]
-theorem vecDirac.arg_x.parDistribDeriv_rule
-    (x : W → X) (y : W → Y) (hx : CDifferentiable R x) (hy : CDifferentiable R y) :
-    parDistribDeriv (R:=R) (fun w => vecDirac (x w) (y w))
+theorem dirac.arg_x.parDistribDeriv_rule
+    (x : W → X) (hx : CDifferentiable R x) :
+    parDistribDeriv (R:=R) (fun w => dirac (x w))
     =
     fun w dw =>
       let xdx := fwdDeriv R x w dw
-      let ydy := fwdDeriv R y w dw
-      vecDiracDeriv xdx.1 xdx.2 ydy.1 ydy.2 := by --= (dpure (R:=R) ydy.1 ydy.2) := by
+      diracDeriv xdx.1 xdx.2 := by --= (dpure (R:=R) ydy.1 ydy.2) := by
   funext w dw; ext φ
-  unfold parDistribDeriv vecDirac vecDiracDeriv
+  unfold parDistribDeriv dirac diracDeriv
   simp [pure, fwdDeriv, DistribDifferentiableAt]
   fun_trans
-  . intro φ' hφ' h
-    have : CDifferentiableAt R (fun w : W => (φ' w) (x w) • (y w)) w := by fun_prop
-    contradiction
 
 
 ----------------------------------------------------------------------------------------------------
@@ -176,6 +183,49 @@ theorem Bind.bind.arg_fx.parDistribDiff_rule
 
 
 
+----------------------------------------------------------------------------------------------------
+-- Move these around -------------------------------------------------------------------------------
+----------------------------------------------------------------------------------------------------
+
+@[fun_prop]
+theorem Distribution.restrict.arg_T.IsSmoothLinearMap_rule (T : W → 𝒟'(X,Y)) (A : Set X)
+    (hT : IsSmoothLinearMap R T) :
+    IsSmoothLinearMap R (fun w => (T w).restrict A) := sorry_proof
+
+@[fun_prop]
+theorem Distribution.restrict.arg_T.IsSmoothLinearMap_rule_simple (A : Set X) :
+    IsSmoothLinearMap R (fun (T : 𝒟'(X,Y)) => T.restrict A) := sorry_proof
+
+@[fun_prop]
+theorem Function.toDistribution.arg_f.CDifferentiable_rule (f : W → X → Y)
+    (hf : ∀ x, CDifferentiable R (f · x)) :
+    CDifferentiable R (fun w => (fun x => f w x).toDistribution (R:=R)) := sorry_proof
+
+@[fun_trans]
+theorem Function.toDistribution.arg_f.cderiv_rule (f : W → X → Y)
+    (hf : ∀ x, CDifferentiable R (f · x)) :
+    cderiv R (fun w => (fun x => f w x).toDistribution (R:=R))
+    =
+    fun w dw =>
+      (fun x =>
+        let dy := cderiv R (f · x) w dw
+        dy).toDistribution := sorry_proof
+
+@[fun_trans]
+theorem toDistribution.linear_parDistribDeriv_rule (f : W → X → Y) (L : Y → Z)
+    (hL : IsSmoothLinearMap R L) :
+    parDistribDeriv (fun w => (fun x => L (f w x)).toDistribution)
+    =
+    fun w dw =>
+      parDistribDeriv (fun w => (fun x => f w x).toDistribution) w dw |>.postComp L := by
+  funext w dw
+  unfold parDistribDeriv Distribution.postComp Function.toDistribution
+  ext φ
+  simp [ftrans_simp, Distribution.mk_extAction_simproc]
+  sorry_proof
+
+
+
---------------------------------------------------------------------------------------------------- -- Integral ---------------------------------------------------------------------------------------- ---------------------------------------------------------------------------------------------------- @@ -201,6 +251,8 @@ theorem cintegral.arg_f.cderiv_distrib_rule' (f : W β†’ X β†’ R) (A : Set X): -- (parDistribDeriv (fun w => (f w Β·).toDistribution) w dw).extAction (fun x => if x ∈ A then 1 else 0) := sorry_proof + + @[fun_trans] theorem cintegral.arg_f.parDistribDeriv_rule (f : W β†’ X β†’ Y β†’ R) : parDistribDeriv (fun w => (fun x => ∫' y, f w x y).toDistribution) diff --git a/SciLean/Core/Distribution/ParametricDistribRevDeriv.lean b/SciLean/Core/Distribution/ParametricDistribRevDeriv.lean index dea05fa5..64341d05 100644 --- a/SciLean/Core/Distribution/ParametricDistribRevDeriv.lean +++ b/SciLean/Core/Distribution/ParametricDistribRevDeriv.lean @@ -1,4 +1,6 @@ import SciLean.Core.Distribution.ParametricDistribDeriv +import SciLean.Core.Distribution.ParametricDistribFwdDeriv +import SciLean.Core.Distribution.Eval namespace SciLean @@ -20,14 +22,14 @@ variable set_default_scalar R + @[fun_trans] noncomputable def parDistribRevDeriv (f : X β†’ π’Ÿ'(Y,Z)) (x : X) : π’Ÿ'(Y,ZΓ—(Zβ†’X)) := ⟨⟨fun Ο† => let dz := semiAdjoint R (fun dx => βŸͺparDistribDeriv f x dx,Ο†βŸ«) let z := βŸͺf x, Ο†βŸ« - (z, sorry), sorry_proof⟩⟩ - + (z, dz), sorry_proof⟩⟩ namespace parDistribRevDeriv @@ -35,13 +37,22 @@ namespace parDistribRevDeriv theorem comp_rule (f : Y β†’ π’Ÿ'(Z,U)) (g : X β†’ Y) - (hf : DistribDifferentiable f) (hg : CDifferentiable R g) : + (hf : DistribDifferentiable f) (hg : HasAdjDiff R g) : parDistribRevDeriv (fun x => f (g x)) = fun x => let ydg := revDeriv R g x let udf := parDistribRevDeriv f ydg.1 - udf.postComp (fun (u,df') => (u, fun du => ydg.2 (df' du))) := by sorry_proof + udf.postComp (fun (u,df') => (u, fun du => ydg.2 (df' du))) := by + + unfold 
parDistribRevDeriv + funext x; ext Ο† + simp + fun_trans + simp [action_push,revDeriv,fwdDeriv] + have : βˆ€ x, HasSemiAdjoint R (βˆ‚ x':=x, βŸͺf x', Ο†βŸ«) := sorry_proof -- todo add: `DistribHasAdjDiff` + fun_trans + theorem bind_rule @@ -52,3 +63,51 @@ theorem bind_rule let ydg := parDistribRevDeriv g x let zdf := fun y => parDistribRevDeriv (f Β· y) x ydg.bind' zdf (fun (_,dg) (z,df) => (z, fun dr => dg dr + df dr)) := sorry_proof + + +theorem bind_rule' + (f : X β†’ Y β†’ π’Ÿ'(Z,V)) (g : X β†’ π’Ÿ'(Y,U)) (L : U β†’ V β†’ W) : + parDistribRevDeriv (fun x => (g x).bind' (f x) L) + = + fun x => + let ydg := parDistribRevDeriv g x + let zdf := fun y => parDistribRevDeriv (f Β· y) x + ydg.bind' zdf (fun (u,dg) (v,df) => + (L u v, fun dw => + df (semiAdjoint R (L u Β·) dw) + + dg (semiAdjoint R (L Β· v) dw))) := by + + unfold parDistribRevDeriv bind' + funext x; ext Ο† + simp + sorry_proof + sorry_proof + + + +---------------------------------------------------------------------------------------------------- +-- Dirac ------------------------------------------------------------------------------------------- +---------------------------------------------------------------------------------------------------- + +noncomputable +def diracRevDeriv (x : X) : π’Ÿ'(X,RΓ—(Rβ†’X)) := + ⟨⟨fun Ο† => revDeriv R Ο† x, sorry_proof⟩⟩ + + +@[fun_trans] +theorem dirac.arg_xy.parDistribRevDeriv_rule + (x : W β†’ X) (hx : HasAdjDiff R x) : + parDistribRevDeriv (fun w => dirac (x w)) + = + fun w => + let xdx := revDeriv R x w + diracRevDeriv xdx.1 |>.postComp (fun (r,dΟ†) => (r, fun dr => xdx.2 (dΟ† dr))) := by + + funext w; apply Distribution.ext _ _; intro Ο† + have : HasAdjDiff R Ο† := sorry_proof -- this should be consequence of that `R` has dimension one + simp [diracRevDeriv,revDeriv, parDistribRevDeriv] + fun_trans + + + +#check Distribution.postComp diff --git a/SciLean/Core/Distribution/SimpleExamples.lean b/SciLean/Core/Distribution/SimpleExamples.lean index 
3d8acf9b..e49d8eb8 100644 --- a/SciLean/Core/Distribution/SimpleExamples.lean +++ b/SciLean/Core/Distribution/SimpleExamples.lean @@ -34,8 +34,6 @@ theorem _root_.FiniteDimensional.finrank_unit : finrank R Unit = 0 := by sorry_p variable [MeasureSpace R] -- [Module ℝ R] - - def foo1 (t' : R) := (βˆ‚ (t:=t'), ∫' (x:R) in Ioo 0 1, if x ≀ t then (1:R) else 0) rewrite_by fun_trans only [scalarGradient, scalarCDeriv] @@ -49,12 +47,8 @@ theorem foo1_spec (t : R) : #eval foo1 (-1.0) -- 0.0 #eval foo1 2.0 -- 0.0 -#check Set.add_empty - open Classical in -set_option pp.funBinderTypes true in - def foo2 (t' : R) := (βˆ‚ (t:=t'), ∫' (x:R) in Ioo 0 1, if x - t ≀ 0 then (1:R) else 0) rewrite_by fun_trans only [scalarGradient, scalarCDeriv] diff --git a/SciLean/Core/Distribution/SimpleExamples2D.lean b/SciLean/Core/Distribution/SimpleExamples2D.lean index 3eb2165d..fec4a952 100644 --- a/SciLean/Core/Distribution/SimpleExamples2D.lean +++ b/SciLean/Core/Distribution/SimpleExamples2D.lean @@ -75,6 +75,9 @@ def foo1' (t' : R) := simp only [ftrans_simp] simp only [Tactic.if_pull] fun_trans only [scalarGradient, scalarCDeriv,ftrans_simp] + unfold Distribution.postExtAction + rw[postComp_restrict_extAction (x:=dirac t') (A:= Ioo 0 1) (Ο†:=fun _ => 1)] + simp [ftrans_simp, postComp_restrict_extAction] simp (disch:=sorry) only [action_push, ftrans_simp] rand_pull_E simp @@ -82,15 +85,6 @@ def foo1' (t' : R) := #eval Rand.print_mean_variance (foo1' 0.3) 100 " of foo1'" --- open Scalar in --- def foo1'' (t' : R) := --- derive_random_approx --- (βˆ‚ (t:=t'), ∫' (x : R) in Ioo 0 1, sqrt (∫' (y : R) in Ioo 0 1, if x ≀ t then (1:R) else 0)) --- by --- fun_trans only [scalarGradient, scalarCDeriv, if_pull, ftrans_simp] --- simp only [action_push, ftrans_simp] - - def foo2 (t' : R) := derive_random_approx (βˆ‚ (t:=t'), ∫' (xy : RΓ—R) in (Ioo 0 1).prod (Ioo 0 1), if xy.1 + xy.2 ≀ t then (1:R) else 0) @@ -116,8 +110,6 @@ def foo2 (t' : R) := simp (disch:=sorry) only [ftrans_simp] rand_pull_E 
-π’Ÿ'(X,π’Ÿ'(Y,ℝ)) := L(π’Ÿ X, Y) - #eval Rand.print_mean_variance (foo2 0.3) 1000 "" #eval Rand.print_mean_variance (foo2 1.7) 1000 "" @@ -155,12 +147,19 @@ def foo3 (t' : R) := #eval Rand.print_mean_variance (foo3 1.7) 10000 "" +variable [Module ℝ Z] [MeasureSpace X] [Module ℝ Y] + + +#exit + +set_option profiler true in +set_option trace.Meta.Tactic.fun_trans true in +set_option trace.Meta.Tactic.fun_prop true in def foo4 (t' : R) := derive_random_approx (βˆ‚ (t:=t'), ∫' (x : R) in Ioo 0 1, ∫' (y : R) in Ioo 0 1, if x ≀ t then x*y*t else x+y+t) by fun_trans only [scalarGradient, scalarCDeriv] - simp only [ftrans_simp] simp only [Tactic.if_pull] fun_trans only [scalarGradient, scalarCDeriv,ftrans_simp] simp (disch:=sorry) only [action_push, ftrans_simp] diff --git a/SciLean/Core/Distribution/SurfaceDirac.lean b/SciLean/Core/Distribution/SurfaceDirac.lean index c99d82ad..af3cceba 100644 --- a/SciLean/Core/Distribution/SurfaceDirac.lean +++ b/SciLean/Core/Distribution/SurfaceDirac.lean @@ -35,7 +35,7 @@ theorem surfaceDirac_extAction (A : Set X) (f : X β†’ Y) (d : β„•) (Ο† : X β†’ R @[simp, ftrans_simp] -theorem surfaceDirac_dirac (f : X β†’ Y) (x : X) : surfaceDirac {x} f 0 = vecDirac x (f x) := by +theorem surfaceDirac_dirac (f : X β†’ Y) (x : X) : surfaceDirac {x} f 0 = (dirac x).postComp (fun r => r β€’ (f x)) := by ext Ο† unfold surfaceDirac; dsimp sorry_proof