Skip to content

Commit

Permalink
clean up gaussian function and exp and log simp theorems
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Dec 7, 2024
1 parent 2da870e commit 88566a9
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 177 deletions.
5 changes: 5 additions & 0 deletions SciLean/Algebra/Dimension.lean
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ elab "dim(" X:term ")" : term => do
let (dim,_) ← elabConvRewrite dim #[] (← `(conv| simp -failIfUnchanged))
return dim

/-- The `Module.finrank` of `X` over `R` equals the number `d` recorded by a
`Dimension R X d` instance.

Registered with both `simp` and `simp_core`. The proof is exactly the `is_dim`
field of the instance `hd`. -/
@[simp, simp_core]
theorem finrank_dimension {R X d}
    [Ring R] [AddCommGroup X] [Module R X] [hd : Dimension R X d] :
    Module.finrank R X = d := by
  exact hd.is_dim



-- `ℝ` regarded as a module over itself is given dimension `1`.
-- The `is_dim` obligation is discharged by `simp`.
instance : Dimension ℝ ℝ 1 where
is_dim := by simp
Expand Down
51 changes: 34 additions & 17 deletions SciLean/Analysis/SpecialFunctions/Exp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import SciLean.Analysis.Calculus.RevFDeriv
import SciLean.Analysis.Calculus.RevCDeriv
import SciLean.Analysis.Calculus.FwdFDeriv
import SciLean.Analysis.Calculus.FwdCDeriv
import SciLean.Analysis.Calculus.ContDiff

open ComplexConjugate

Expand All @@ -15,24 +16,12 @@ variable
{U} [SemiInnerProductSpace C U]


--------------------------------------------------------------------------------
-- Exp -------------------------------------------------------------------------
--------------------------------------------------------------------------------

set_option linter.unusedVariables false in
@[fun_prop]
theorem exp.arg_x.DifferentiableAt_rule
{W} [NormedAddCommGroup W] [NormedSpace C W]
(w : W) (x : W → C) (hx : DifferentiableAt C x w) :
DifferentiableAt C (fun w => exp (x w)) w := sorry_proof


@[fun_prop]
theorem exp.arg_x.Differentiable_rule
{W} [NormedAddCommGroup W] [NormedSpace C W]
(x : W → C) (hx : Differentiable C x) :
Differentiable C fun w => exp (x w) := by intro x; fun_prop
----------------------------------------------------------------------------------------------------
-- Exp ---------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------

-- Declare two function properties of `exp` in its argument `x`:
-- `Differentiable K` and `ContDiff K ⊤`. Both proofs are deferred with `sorry_proof`.
-- NOTE(review): the scalar is written `K`, but the visible `variable` block of this file
-- only shows `C` — confirm that `K` is in scope here.
-- NOTE(review): `with_transitive` presumably also generates the composition variants;
-- verify against the `def_fun_prop` macro.
def_fun_prop exp in x with_transitive : Differentiable K by sorry_proof
def_fun_prop exp in x with_transitive : ContDiff K ⊤ by sorry_proof

set_option linter.unusedVariables false in
@[fun_trans]
Expand Down Expand Up @@ -128,3 +117,31 @@ theorem exp.arg_x.revCDeriv_rule
(exp xdx.1, fun dy => xdx.2 (conj (exp xdx.1) * dy)) := by
unfold revCDeriv
fun_trans [fwdCDeriv, smul_push, simp_core]




/-- `exp 0 = 1`. Tagged `simp`, `simp_core` and `exp_push`; the proof is deferred
(`sorry_proof`). -/
@[simp, simp_core, exp_push]
theorem exp_zero : exp (0:R) = 1 := sorry_proof
/-- `exp (log x) = abs x`, stated for every `x : R` with no side condition. Tagged `simp`,
`simp_core` and `exp_push`; the proof is deferred (`sorry_proof`).

NOTE(review): if `log 0 = 0` (the convention of Mathlib's `Real.log`), the left side is `1`
and the right side is `0` at `x = 0` — confirm the intended convention for `log` here. -/
@[simp, simp_core, exp_push]
theorem exp_log (x : R) : exp (log x) = abs x := sorry_proof

/-- `exp (x+y) = exp x * exp y`: rewrites `exp` of a sum into a product. Tagged `exp_push`;
the proof is deferred (`sorry_proof`). -/
@[exp_push]
theorem exp_add (x y : R) : exp (x+y) = exp x * exp y := sorry_proof
/-- `exp x * exp y = exp (x+y)`: the same identity as `exp_add`, oriented the other way.
Tagged `exp_pull`; the proof is deferred (`sorry_proof`). -/
@[exp_pull]
theorem mul_exp (x y : R) : exp x * exp y = exp (x+y) := sorry_proof

/-- `exp (x-y) = exp x / exp y`: rewrites `exp` of a difference into a quotient.
Tagged `exp_push`; the proof is deferred (`sorry_proof`). -/
@[exp_push]
theorem exp_sub (x y : R) : exp (x-y) = exp x / exp y := sorry_proof
/-- `exp x / exp y = exp (x-y)`: the same identity as `exp_sub`, oriented the other way.
Tagged `exp_pull`; the proof is deferred (`sorry_proof`). -/
@[exp_pull]
theorem div_exp (x y : R) : exp x / exp y = exp (x-y) := sorry_proof

/-- `exp (-x) = (exp x)⁻¹`: rewrites `exp` of a negation into an inverse. Note that the
left-hand side is a negation, although the lemma is named `exp_inv`.
Tagged `exp_push`; the proof is deferred (`sorry_proof`). -/
@[exp_push]
theorem exp_inv (x : R) : exp (-x) = (exp x)⁻¹ := sorry_proof
/-- `(exp x)⁻¹ = exp (-x)`: the same identity as `exp_inv`, oriented the other way.
Tagged `exp_pull`; the proof is deferred (`sorry_proof`). -/
@[exp_pull]
theorem inv_exp (x : R) : (exp x)⁻¹ = exp (-x) := sorry_proof

/-- `exp (x*y) = (exp x)^y`: rewrites `exp` of a product into a power.
Tagged `exp_push`; the proof is deferred (`sorry_proof`).

The product must be parenthesised: function application binds tighter than `*`, so
`exp x*y` would mean `(exp x) * y`, not `exp (x*y)`. -/
@[exp_push]
theorem exp_mul (x y : R) : exp (x*y) = (exp x)^y := sorry_proof
/-- `(exp x)^y = exp (x*y)`: rewrites a power of `exp` into `exp` of a product, with the
exponent `y` taken in `R`. Tagged `exp_pull`; the proof is deferred (`sorry_proof`). -/
@[exp_pull]
theorem pow_exp (x y : R) : (exp x)^y = exp (x*y) := sorry_proof
114 changes: 49 additions & 65 deletions SciLean/Analysis/SpecialFunctions/Gaussian.lean
Original file line number Diff line number Diff line change
@@ -1,102 +1,86 @@
import SciLean.Algebra.Dimension
import SciLean.Analysis.Calculus.FDeriv
import SciLean.Analysis.Calculus.ContDiff
import SciLean.Analysis.SpecialFunctions.Exp
import SciLean.Analysis.SpecialFunctions.Log
import SciLean.Analysis.SpecialFunctions.Norm2

import SciLean.Analysis.Calculus.FDeriv

import SciLean.Meta.GenerateFunTrans
import SciLean.Meta.Notation.Let'
import SciLean.Tactic.Autodiff
import SciLean.Lean.ToSSA

open ComplexConjugate
import Mathlib.Probability.Distributions.Gaussian

namespace SciLean

open Scalar RealScalar ComplexConjugate

set_option deprecated.oldSectionVars true

variable
{R C} [Scalar R C] [RealScalar R]
{W} [Vec R W]
{U} [SemiHilbert R U]
{X : Type*} [NormedAddCommGroup X] [AdjointSpace R X] [CompleteSpace X] {d : outParam ℕ} [hdim : Dimension R X d]

set_default_scalar R

----------------------------------------------------------------------------------------------------
-- Gaussian ----------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------

open Scalar RealScalar in
def gaussian {U} [Sub U] [SMul R U] [Inner R U] (μ : U) (σ : R) (x : U) : R :=
def gaussian [Dimension R X d] (μ : X) (σ : R) (x : X) : R :=
let x' := σ⁻¹ • (x - μ)
1/(σ*sqrt (2*(pi : R))) * exp (- ‖x'‖₂²/2)

(2*π*σ^2)^(-(d:R)/2) * exp (- ‖x'‖₂²/2)

open Scalar RealScalar in
@[simp, simp_core]
theorem log_gaussian (μ : U) (σ : R) (x : U) :
theorem log_gaussian (μ : X) (σ : R) (x : X) :
log (gaussian μ σ x)
=
let x' := σ⁻¹ • (x - μ)
(- ‖x'‖₂²/2 - log σ - log (sqrt (2*(pi :R)))) := by
(- d/2 * (log (2*π) + 2 * log σ) - ‖x'‖₂²/2 ) := by

unfold gaussian
simp [log_inv,log_mul,log_div,log_exp,log_one]
simp [log_push]
ring

def_fun_prop with_transitive
{X : Type _} [NormedAddCommGroup X] [AdjointSpace R X] (σ : R) :
Differentiable R (fun (μx : X×X) => gaussian μx.1 σ μx.2) by
unfold gaussian; fun_prop

def_fun_prop with_transitive
{X : Type _} [SemiHilbert R X] (σ : R) :
HasAdjDiff R (fun (μx : X×X) => gaussian μx.1 σ μx.2) by
unfold gaussian; fun_prop


section OnAdjointSpace
-- Declare `gaussian` to be `Differentiable R` jointly in the arguments `μ` and `x`.
-- No explicit proof script is given on this line.
-- NOTE(review): presumably the macro falls back to a default tactic — verify against `def_fun_prop`.
def_fun_prop gaussian in μ x with_transitive : Differentiable R

set_option deprecated.oldSectionVars true

variable {U : Type _} [NormedAddCommGroup U] [AdjointSpace R U] [CompleteSpace U]

@[fun_trans]
theorem gaussian.arg_μx.fderiv_rule (σ : R) :
fderiv R (fun μx : U×U => gaussian μx.1 σ μx.2)
=
fun μx => fun dμx =>L[R]
let dx' := - (σ^2)⁻¹ * ⟪dμx.2-dμx.1, μx.2-μx.1
dx' * gaussian μx.1 σ μx.2 := by
ext x dx <;>
(unfold gaussian; simp
conv => lhs; autodiff
simp[smul_pull]
ring)


@[fun_trans]
theorem gaussian.arg_μx.fwdFDeriv_rule (σ : R) :
fwdFDeriv R (fun μx : U×U => gaussian μx.1 σ μx.2)
=
fun μx dμx =>
let x' := gaussian μx.1 σ μx.2
let dx' := - (σ^2)⁻¹ * ⟪dμx.2-dμx.1, μx.2-μx.1
(x', dx' * x') := by
-- `fderiv R` rule for `gaussian`, taken jointly in `(μ, x)`.
-- The closed form asserted below is `dx' * gaussian μ σ x`
-- with `dx' = -(σ^2)⁻¹ * ⟪dx-dμ, x-μ⟫[R]`.
abbrev_fun_trans gaussian in μ x : fderiv R by
equals (fun μx => fun dμx =>L[R]
let' (μ,x) := μx
let' (dμ,dx) := dμx
let dx' := - (σ^2)⁻¹ * ⟪dx-dμ, x-μ⟫[R]
dx' * gaussian μ σ x) =>
-- Proof that the stated form matches: unfold the definition, run `fun_trans`,
-- then compare pointwise and close both components with `simp` and `ring`.
unfold gaussian
fun_trans
funext x;
ext dx <;> (simp[smul_pull]; ring)


-- `fwdFDeriv R` rule for `gaussian`, taken jointly in `(μ, x)`.
-- The asserted result is the pair `(G, dx' * G)`, where `G = gaussian μ σ x` is bound once
-- and reused in both components.
abbrev_fun_trans gaussian in μ x : fwdFDeriv R by
-- ideally
-- unfold fwdFDeriv
-- autodiff
-- run common subexpression elimination
equals (fun μx dμx =>
let' (μ,x) := μx
let' (dμ,dx) := dμx
let dx' := - (σ^2)⁻¹ * ⟪dx-dμ, x-μ⟫[R]
let G := gaussian μ σ x
(G, dx' * G)) =>
-- The stated form is justified by unfolding `fwdFDeriv` and running `fun_trans`.
unfold fwdFDeriv
fun_trans


@[fun_trans]
theorem gaussian.arg_μx.revFDeriv_rule (σ : R) :
revFDeriv R (fun μx : U×U => gaussian μx.1 σ μx.2)
=
fun μx =>
let s := gaussian μx.1 σ μx.2
(s, fun dr =>
let dx := (dr * s * (σ^2)⁻¹) • (μx.1-μx.2)
(- dx, dx)) := by
abbrev_fun_trans gaussian in μ x [CompleteSpace X] : revFDeriv R by
equals (fun μx =>
let' (μ,x) := μx
let G := gaussian μ σ x
(G, fun dr =>
let dx := (G*(σ^2)⁻¹*dr) • (x-μ)
(dx,-dx))) =>
unfold revFDeriv
funext μx; simp; funext dr
fun_trans [smul_smul,neg_push];
ring_nf
simp [smul_sub,neg_sub]

end OnAdjointSpace
funext x; fun_trans
funext dx; simp only [Prod.mk.injEq, neg_inj]
constructor <;> module
45 changes: 37 additions & 8 deletions SciLean/Analysis/SpecialFunctions/Log.lean
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,42 @@ theorem log.arg_x.revCDeriv_rule

end Convenient


@[simp, simp_core]
theorem log_one : Scalar.log (1:R) = 0 := sorry_proof
@[simp, simp_core]
theorem log_exp (x : R) : Scalar.log (Scalar.exp x) = x := sorry_proof
theorem log_mul (x y : R) : Scalar.log (x*y) = Scalar.log x + Scalar.log y := sorry_proof
theorem log_div (x y : R) : Scalar.log (x/y) = Scalar.log x - Scalar.log y := sorry_proof
theorem log_inv (x : R) : Scalar.log x⁻¹ = - Scalar.log x := sorry_proof
open Scalar

/-- `log 1 = 0`. Tagged `simp`, `simp_core` and `log_push`; the proof is deferred
(`sorry_proof`). -/
@[simp, simp_core, log_push]
theorem log_one : log (1:R) = 0 := sorry_proof
/-- `log (exp x) = x`. Tagged `simp`, `simp_core` and `log_push`; the proof is deferred
(`sorry_proof`). -/
@[simp, simp_core, log_push]
theorem log_exp (x : R) : log (exp x) = x := sorry_proof

/-- `log (x*y) = log x + log y`: rewrites `log` of a product into a sum. Tagged `log_push`;
the proof is deferred (`sorry_proof`).

NOTE(review): stated with no hypotheses on `x` and `y`; the analogous Mathlib lemma
`Real.log_mul` requires both factors to be nonzero — confirm this is intended. -/
@[log_push]
theorem log_mul (x y : R) : log (x*y) = log x + log y := sorry_proof
/-- `log x + log y = log (x*y)`: the same identity as `log_mul`, oriented the other way.
Tagged `log_pull`; the proof is deferred (`sorry_proof`). -/
@[log_pull]
theorem add_log (x y : R) : log x + log y = log (x*y) := sorry_proof

/-- `log (x/y) = log x - log y`: rewrites `log` of a quotient into a difference.
Tagged `log_push`; the proof is deferred (`sorry_proof`).

NOTE(review): stated with no hypotheses on `x` and `y`; the analogous Mathlib lemma
`Real.log_div` requires both to be nonzero — confirm this is intended. -/
@[log_push]
theorem log_div (x y : R) : log (x/y) = log x - log y := sorry_proof
/-- `log x - log y = log (x/y)`: the same identity as `log_div`, oriented the other way.
Tagged `log_pull`; the proof is deferred (`sorry_proof`). -/
@[log_pull]
theorem sub_log (x y : R) : log x - log y = log (x/y) := sorry_proof

/-- `log (x⁻¹) = - log x`: rewrites `log` of an inverse into a negation.
Tagged `log_push`; the proof is deferred (`sorry_proof`). -/
@[log_push]
theorem log_inv (x : R) : log (x⁻¹) = - log x := sorry_proof
/-- `- log x = log (x⁻¹)`: the same identity as `log_inv`, oriented the other way.
Tagged `log_pull`; the proof is deferred (`sorry_proof`). -/
@[log_pull]
theorem neg_log (x : R) : - log x = log (x⁻¹) := sorry_proof

/-- `log (x^y) = y * log x` for an exponent `y : R`. Tagged `log_push`; the proof is
deferred (`sorry_proof`).

NOTE(review): stated with no hypothesis on `x`; for a real exponent the analogous Mathlib
lemma `Real.log_rpow` requires `0 < x` — confirm this is intended. -/
@[log_push]
theorem log_pow (x y : R) : log (x^y) = y * log x := sorry_proof
/-- `log (x^n) = n * log x` for a natural-number exponent `n`. Tagged `log_push`; the proof
is deferred (`sorry_proof`). -/
@[log_push]
theorem log_pow_nat (x : R) (n : ℕ) : log (x^n) = n * log x := sorry_proof
/-- `log (x^n) = n * log x` for an integer exponent `n`. Tagged `log_push`; the proof is
deferred (`sorry_proof`). -/
@[log_push]
theorem log_pow_int (x : R) (n : ℤ) : log (x^n) = n * log x := sorry_proof
/-- `y * log x = log (x^y)`: the identity of `log_pow`, oriented the other way, with the
scalar on the left of the product. Tagged `log_pull`; the proof is deferred
(`sorry_proof`). -/
@[log_pull]
theorem mul_log (x y : R) : y * log x = log (x^y) := sorry_proof
/-- `(log x) * y = log (x^y)`: variant of `mul_log` with the scalar on the right of the
product. Tagged `log_pull`; the proof is deferred (`sorry_proof`). -/
@[log_pull]
theorem mul_log' (x y : R) : (log x) * y = log (x^y) := sorry_proof

/-- `log (∏ i, f i) = ∑ i, log (f i)` over an `IndexType I`: rewrites `log` of a product
into a sum of logs. Tagged `log_push`; the proof is deferred (`sorry_proof`).

NOTE(review): stated with no hypothesis on the values of `f` (e.g. nonzero) — confirm this
is intended. -/
@[log_push]
theorem log_prod {I} [IndexType I] (f : I → R) : log (∏ i, f i) = ∑ i, log (f i) := sorry_proof
/-- `(∑ i, log (f i)) = log (∏ i, f i)`: the same identity as `log_prod`, oriented the other
way. Tagged `log_pull`; the proof is deferred (`sorry_proof`). -/
@[log_pull]
theorem sum_log {I} [IndexType I] (f : I → R) : (∑ i, log (f i)) = log (∏ i, f i) := sorry_proof

end Log
5 changes: 5 additions & 0 deletions SciLean/Data/ArrayType/Algebra.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import SciLean.Analysis.Convenient.FinVec
import SciLean.Analysis.AdjointSpace.Basic
import SciLean.Analysis.Scalar.FloatAsReal

import SciLean.Algebra.Dimension

import SciLean.Data.ArrayType.Basic
import SciLean.Data.StructType.Algebra

Expand Down Expand Up @@ -268,6 +270,9 @@ instance [ArrayType Cont Idx Elem] [MeasurableSpace Elem] [TopologicalSpace Elem
measurable_eq := sorry_proof


-- An array type `Cont` over `Idx` whose elements have `Dimension K Elem d` is itself given
-- dimension `(size Idx)*d` over `K`.
-- The `is_dim` proof rewrites only the left-hand side of the goal with `simp` (inside `conv`).
instance {d} [ArrayType Cont Idx Elem] [AddCommGroup Elem] [Module K Elem] [Dimension K Elem d] :
Dimension K Cont ((size Idx)*d) where
is_dim := by conv => lhs; simp

-- This is a problem, as `Vec` and `NormedAddCommGroup` provide different topologies on `Elem`
-- example {R} [RCLike R] [ArrayType Cont Idx Elem] [NormedAddCommGroup Elem] [NormedSpace ℝ Elem] [Vec R Elem] :
Expand Down
46 changes: 21 additions & 25 deletions SciLean/Data/DataArray/Operations/GaussianN.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,44 +18,47 @@ open Scalar RealScalar
The reason why it is symbolic is that you do not want to compute the determinant and inverse of `σ`. -/
noncomputable
def gaussianN {n : ℕ} (μ : R^[n]) (S : R^[n,n]) (x : R^[n]) : R :=
def gaussianS {n : ℕ} (μ : R^[n]) (S : R^[n,n]) (x : R^[n]) : R :=
let x' := x-μ
(2*π)^(-(n:R)/2) * S.det^(-(1:R)/2) * exp (- ⟪x', (S⁻¹)*x'⟫/2)
#check Module.finrank

def_fun_prop gaussianN in μ x : Differentiable R

def_fun_prop gaussianS in μ x : Differentiable R

abbrev_fun_trans gaussianN in μ x : fderiv R by

abbrev_fun_trans gaussianS in μ x : fderiv R by
equals (fun μx => fun dμx : R^[n]×R^[n] =>L[R]
let' (μ,x) := μx
let' (dμ,dx) := dμx
let x' := x-μ
let dx' := dx-dμ
(-2⁻¹)*(⟪dx',S⁻¹*x'⟫[R] + ⟪x',S⁻¹*dx'⟫[R])*gaussianN μ S x) =>
unfold gaussianN
let G := gaussianS μ S x
let ds := ⟪dx',S⁻¹*x'⟫[R] + ⟪x',S⁻¹*dx'⟫[R]
(-2⁻¹)*ds*G) =>
unfold gaussianS
fun_trans
funext x; dsimp;
ext dx <;> (simp; ring)


abbrev_fun_trans gaussianN in μ x : fwdFDeriv R by
abbrev_fun_trans gaussianS in μ x : fwdFDeriv R by
equals (fun μx dμx : R^[n] × R^[n] =>
let' (μ,x) := μx
let' (dμ,dx) := dμx
let x' := x-μ
let dx' := dx-dμ
let G := gaussianN μ S x
(G, (-2⁻¹)*(⟪dx',S⁻¹*x'⟫[R] + ⟪x',S⁻¹*dx'⟫[R])*G)) =>
let G := gaussianS μ S x
let ds := ⟪dx',S⁻¹*x'⟫[R] + ⟪x',S⁻¹*dx'⟫[R]
(G, (-2⁻¹)*ds*G)) =>
unfold fwdFDeriv
fun_trans


abbrev_fun_trans gaussianN in μ x : revFDeriv R by
abbrev_fun_trans gaussianS in μ x : revFDeriv R by
equals (fun μx : R^[n] × R^[n] =>
let' (μ,x) := μx
let x' := x-μ
let G := gaussianN μ S x
let G := gaussianS μ S x
(G, fun dr =>
let dx := (-2⁻¹*dr)•(S⁻ᵀ*x' + S⁻¹*x')
(-G•dx,G•dx))) =>
Expand All @@ -77,9 +80,9 @@ omit [PlainDataType R] in
theorem RealScalar.one_pow (x : R) : (1:R)^x = 1 := sorry_proof


theorem gaussianN_ATA {μ : R^[n]} {A : R^[n,n]} {x : R^[n]} (hA : A.Invertible) :
gaussianN μ ((Aᵀ*A)⁻¹) x = A.det * gaussianN 0 𝐈 (A*(x-μ)) := by
unfold gaussianN
theorem gaussianS_ATA' {μ : R^[n]} {A : R^[n,n]} {x : R^[n]} (hA : A.Invertible) :
gaussianS μ ((Aᵀ*A)⁻¹) x = A.det * gaussianS 0 𝐈 (A*(x-μ)) := by
unfold gaussianS
simp (disch:=simp[hA]) only [det_inv_eq_inv_det, det_mul, det_transpose, mul_inv_rev,
DataArrayN.inv_inv, vecmul_assoc, transpose_transpose, inner_self, det_identity, mul_one,
sub_zero, inv_identity,identity_vecmul, mul_eq_mul_right_iff,RealScalar.one_pow,inner_ATA_right]
Expand All @@ -88,16 +91,9 @@ theorem gaussianN_ATA {μ : R^[n]} {A : R^[n,n]} {x : R^[n]} (hA : A.Invertible)
simp[h]


-- NOTE: gaussian - has incorrect definition right now !!!
theorem gaussianN_ATA' (μ : R^[n]) (A : R^[n,n]) (hA : A.Invertible) (x : R^[n]) :
gaussianN μ ((Aᵀ*A)⁻¹) x = A.det * gaussian 0 1 (A*(x-μ)) := by
theorem gaussianS_ATA (μ : R^[n]) (A : R^[n,n]) (hA : A.Invertible) (x : R^[n]) :
gaussianS μ ((Aᵀ*A)⁻¹) x = A.det * gaussian 0 1 (A*(x-μ)) := by

rw[gaussianN_ATA hA]
unfold gaussian gaussianN
rw[gaussianS_ATA' hA]
unfold gaussian gaussianS
simp
have h : (sqrt (2 * π))⁻¹ = (2*π)^(-(1:R)/2) := sorry
rw[h]
ring_nf
sorry_proof -- almost done

#check gaussian
Loading

0 comments on commit 88566a9

Please sign in to comment.