Skip to content

Commit

Permalink
attempt at array traces for probabilistic programs
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Jul 17, 2024
1 parent 9240c79 commit ab76968
Showing 1 changed file with 158 additions and 6 deletions.
164 changes: 158 additions & 6 deletions SciLean/Core/Rand/RandWithTrace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,23 @@ inductive Trace where
| nil
| single (tag : Name) (T : Type)
| pair (t s : Trace)
| array (t : Trace) (n : Nat)

instance : Append Trace := ⟨fun t s => .pair t s⟩

def Trace.type : (trace : Trace) → Type
| .nil => Unit
| .single n T => T
| .pair t s => t.type × s.type
| .array t n => ArrayN.{0,0} t.type n -- {a : Array t.type // a.size = n}

def Trace.tags : (trace : Trace) → List Name
| .nil => []
| .single n _ => [n]
| .pair t s => t.tags ++ s.tags
| .array t n =>
let ts := t.tags.map (fun tag => (List.range n).map (fun i => tag.append (.mkSimple (toString i))))
ts.foldl (init:=[]) (·++·)

structure RandWithTrace (X : Type) (trace : Trace) (T : Type) where
rand : Rand X
Expand All @@ -39,8 +44,23 @@ def return' (x : X) : RandWithTrace X .nil Unit where
hmap := by simp
htype := rfl

macro "trace_bind_tac" : tactic =>
`(tactic| first | native_decide
| simp only [Trace.tags,
List.inter,
List.elem_eq_mem,
List.find?_nil,
List.filter_cons_of_neg,
List.filter_nil,
Bool.false_eq_true,
not_false_eq_true
decide_False] )



def RandWithTrace.bind (x : RandWithTrace X t T) (f : X → RandWithTrace Y s S)
(hinter : t.tags.inter s.tags = [] := by simp[Trace.tags,List.inter]; done) : RandWithTrace Y (t++s) (T×S) where
(hinter : t.tags.inter s.tags = [] := by trace_bind_tac) :
RandWithTrace Y (t++s) (T×S) where
rand := x.rand >>= (fun x' => (f x').rand)
traceRand := do
let tx ← x.traceRand
Expand Down Expand Up @@ -123,8 +143,33 @@ def test2 :=
variable {R} [RealScalar R]
open MeasureTheory


def forLoop (f : Nat → X → RandWithTrace X t T) (init : X) (n : Nat) :
RandWithTrace X (.array t n) (ArrayN.{0,0} T n) where
rand := do
let mut x := init
for i in [0:n] do
x ← (f i x).rand
return x
traceRand := do
let mut ws : Array T := #[]
let mut x := init
for i in [0:n] do
let w ← (f i x).traceRand
ws := ws.push w
x := (f i x).map w
return ⟨ws, sorry_proof⟩
map := fun ws => Id.run do
let mut x := init
for w in ws.1, i in [0:n] do
x := (f i x).map w
return x
hmap := sorry_proof
htype := by sorry_proof

@[simp]
theorem trace_bind_pdf (x : RandWithTrace X t T) (f : X → RandWithTrace Y s S)
(h : t.tags.inter s.tags = [] := by simp[Trace.tags,List.inter])
(h : t.tags.inter s.tags = [] := by trace_bind_tac)
[MeasureSpace S] [MeasureSpace T] :
((x.bind f h).traceRand).pdf R
=
Expand Down Expand Up @@ -154,19 +199,126 @@ theorem trace_sample_pdf (x : Rand X) (n : Name) [MeasureSpace X] :
simp


set_option trace.Meta.Tactic.simp.unify true
set_option trace.Meta.Tactic.simp.discharge true
set_option trace.Meta.Tactic.simp.rewrite true in
#check (let x <~ sample (normal (0.0:Float) 1.0) `v1; return' x)


#check ((let x <~ sample (normal (0.0:Float) 1.0) `v1; return' x).traceRand.pdf Float)
rewrite_by
simp[trace_bind_pdf]


set_option pp.funBinderTypes true
def tt :=
(let (_,x) <~ forLoop (init:=(0.0,#[])) (n:=50) (fun i (x,xs) =>
let x' <~ sample (normal x 1.0) `v1
return' (x', xs.push x'))
let y <~ sample (normal x.1.sum 1.0) `y
return' y)


instance (n:Nat) [MeasureSpace X] : MeasureSpace {a : Array X // a.size = n} := sorry
instance (n:Nat) [MeasureSpace X] : MeasureSpace (ArrayN.{_,0} X n) := sorry


#check (tt.traceRand.pdf Float) rewrite_by
unfold tt
simp


#check (
(let x <~ sample (normal (0.0:Float) 1.0) `v1
let y <~ sample (normal x 1.0) `v2
return' (x*y)).traceRand.pdf Float)
return' y).traceRand.pdf Float)
rewrite_by
simp



#check ((let x <~ sample (normal (0.0:Float) 1.0) `v1;
return' x).traceRand.pdf Float)
rewrite_by
simp[trace_bind_pdf]


#check let x <~ sample (normal (0.0:Float) 1.0) `v1;
return' x


def temperature :=
(let tempLower := 2.0
let tempUpper := 4.0
let invCR := 0.1
let tempA := 3.0
let rpRate := 5.0
let sigma1 := 1.0
let sigma0 := 2.0
let sqrtTimeStep := 0.1

let aconInit := false
let acons := #[aconInit]
let tempInit <~ sample (normal 2.0 0.001) `temp_init
let temps := #[tempInit]
let dataInit <~ sample (normal tempInit 1.0) `data_init
let data := #[dataInit]

let (_,_,temps,acons,data) <~
forLoop (init:=(tempInit,aconInit,temps,acons,data)) (n:=20)
(fun i (tempPrev,aconPrev,temps,acons,data) =>
let m :=
if tempPrev < tempLower then false
else if tempPrev > tempUpper then true
else aconPrev

let aconNoise <~ sample (normal m.toNat.toFloat 0.001) `acon_noise
let acon := if aconNoise > 0.5 then true else false

let b := invCR * (tempA - (tempPrev + acon.toNat.toFloat * rpRate))
let s := if aconNoise > 0.5 then sigma1 else sigma0

let temp <~ sample (normal (tempPrev + b) (s * sqrtTimeStep)) `temp

let dataPoint <~ sample (normal temp 1.0) `data

return' (tempPrev,aconPrev,temps.push temp,acons.push acon,data.push dataPoint))
return' ())

-- @[gtrans outparams tr₁ ...]

structure HasConditionalRand {X tr T} (x : RandWithTrace X tr T) (tags : List Name)
(tr₁ tr₂ : Trace) (T₁ T₂ : Type)
(p : T → T₁×T₂) (q : T₁ → T₂ → T)
(y : RandWithTrace T₁ tr₁ T₁) (z : T₁ → RandWithTrace X tr₂ T₂) : Prop where
trace_type₁ : tr₁.type = T₁
trace_type₂ : tr₂.type = T₂
left_inv : Function.LeftInverse p ↿q
right_inv : Function.RightInverse p ↿q
trace_tags : tr₁.tags = tags
trace_union : tr₁ ++ tr₂ = tr
trace_inter : tr₁.tags.inter tr₂.tags = []
hrand : x.traceRand = (do
let ty ← y.traceRand;
let tz ← (z ty).traceRand
return q ty tz)
hmap : x.map = fun w =>
let (u,v) := u
(z (y.map u)).map v


theorem HasConditionalRand.empty_rule (x : RandWithTrace X tr T) :
HasConditionalRand x []
.nil tr Unit T (fun t => ((),t)) (fun _ t => t)
(return' ()) (fun _ => x) := sorry_proof


theorem HasConditionalRand.bind_rule
(x : RandWithTrace X t T) (f : X → RandWithTrace Y s S) (tags : List Name)
(hinter : t.tags.inter s.tags = [])
(hx : HasConditionalRand x (t.tags.inter tags) t₁ t₂ T₁ T₂ px qx x₁ x₂)
{y₁ : X → RandWithTrace Unit s₁ S₁} {y₂ : X → S₁ → RandWithTrace Y s₂ S₂}
(hy : ∀ x, HasConditionalRand (f x) (s.tags.inter tags) s₁ s₂ S₁ S₂ py qy (y₁ x) (y₂ x)) :
HasConditionalRand (x.bind f hinterx) tags
(t₁++s₁) (t₂++s₂) (T₁×S₁) (T₂×S₂)
(fun (u,v) => (((px u).1, (py v).1),((px u).2, (py v).2)))
(fun (u₁,v₁) (u₂,v₂) => (qx u₁ u₂, qy v₁ v₂))
(x₁.bind (fun tx => y₁ (x₁.map tx)) sorry)
sorry := sorry_proof

0 comments on commit ab76968

Please sign in to comment.