Skip to content

Commit

Permalink
feat: Array.forIn', and relate to List (#5833)
Browse files Browse the repository at this point in the history
Adds support for `for h : x in my_array do`, and relates this to the
existing `List` version.
  • Loading branch information
kim-em authored Oct 25, 2024
1 parent 193b6f2 commit 07ea626
Show file tree
Hide file tree
Showing 12 changed files with 203 additions and 27 deletions.
22 changes: 22 additions & 0 deletions src/Init/Control/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,28 @@ import Init.Core

universe u v w

/--
A `ForIn'` instance, which handles `for h : x in c do`,
can also handle `for x in x do` by ignoring `h`, and so provides a `ForIn` instance.
-/
instance (priority := low) instForInOfForIn' [ForIn' m ρ α d] : ForIn m ρ α where
forIn x b f := forIn' x b fun a _ => f a

@[simp] theorem forIn'_eq_forIn [d : Membership α ρ] [ForIn' m ρ α d] {β} [Monad m] (x : ρ) (b : β)
(f : (a : α) → a ∈ x → β → m (ForInStep β)) (g : (a : α) → β → m (ForInStep β))
(h : ∀ a m b, f a m b = g a b) :
forIn' x b f = forIn x b g := by
simp [instForInOfForIn']
congr
apply funext
intro a
apply funext
intro m
apply funext
intro b
simp [h]
rfl

@[reducible]
def Functor.mapRev {f : Type u → Type v} [Functor f] {α β : Type u} : f α → (α → β) → f β :=
fun a f => f <$> a
Expand Down
1 change: 0 additions & 1 deletion src/Init/Core.lean
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,6 @@ class ForIn' (m : Type u₁ → Type u₂) (ρ : Type u) (α : outParam (Type v)

export ForIn' (forIn')


/--
Auxiliary type used to compile `do` notation. It is used when compiling a do block
nested inside a combinator like `tryCatch`. It encodes the possible ways the
Expand Down
47 changes: 47 additions & 0 deletions src/Init/Data/Array/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,22 @@ theorem ext' {as bs : Array α} (h : as.toList = bs.toList) : as = bs := by

@[simp] theorem getElem_toList {a : Array α} {i : Nat} (h : i < a.size) : a.toList[i] = a[i] := rfl

/-- `a ∈ as` is a predicate which asserts that `a` is in the array `as`. -/
-- NB: This is defined as a structure rather than a plain def so that a lemma
-- like `sizeOf_lt_of_mem` will not apply with no actual arrays around.
structure Mem (as : Array α) (a : α) : Prop where
val : a ∈ as.toList

instance : Membership α (Array α) where
mem := Mem

theorem mem_def {a : α} {as : Array α} : a ∈ as ↔ a ∈ as.toList :=
fun | .mk h => h, Array.Mem.mk⟩

@[simp] theorem getElem_mem {l : Array α} {i : Nat} (h : i < l.size) : l[i] ∈ l := by
rw [Array.mem_def, ← getElem_toList]
apply List.getElem_mem

end Array

namespace List
Expand Down Expand Up @@ -316,6 +332,37 @@ protected def forIn {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m
instance : ForIn m (Array α) α where
forIn := Array.forIn

/-- See comment at `forInUnsafe` -/
@[inline] unsafe def forIn'Unsafe {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : Array α) (b : β) (f : (a : α) → a ∈ as → β → m (ForInStep β)) : m β :=
let sz := as.usize
let rec @[specialize] loop (i : USize) (b : β) : m β := do
if i < sz then
let a := as.uget i lcProof
match (← f a lcProof b) with
| ForInStep.done b => pure b
| ForInStep.yield b => loop (i+1) b
else
pure b
loop 0 b

/-- Reference implementation for `forIn'` -/
@[implemented_by Array.forIn'Unsafe]
protected def forIn' {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : Array α) (b : β) (f : (a : α) → a ∈ as → β → m (ForInStep β)) : m β :=
let rec loop (i : Nat) (h : i ≤ as.size) (b : β) : m β := do
match i, h with
| 0, _ => pure b
| i+1, h =>
have h' : i < as.size := Nat.lt_of_lt_of_le (Nat.lt_succ_self i) h
have : as.size - 1 < as.size := Nat.sub_lt (Nat.zero_lt_of_lt h') (by decide)
have : as.size - 1 - i < as.size := Nat.lt_of_le_of_lt (Nat.sub_le (as.size - 1) i) this
match (← f as[as.size - 1 - i] (getElem_mem this) b) with
| ForInStep.done b => pure b
| ForInStep.yield b => loop i (Nat.le_of_lt h') b
loop as.size (Nat.le_refl _) b

instance : ForIn' m (Array α) α inferInstance where
forIn' := Array.forIn'

/-- See comment at `forInUnsafe` -/
@[inline]
unsafe def foldlMUnsafe {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : β → α → m β) (init : β) (as : Array α) (start := 0) (stop := as.size) : m β :=
Expand Down
45 changes: 33 additions & 12 deletions src/Init/Data/Array/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ namespace Array

@[simp] theorem getElem_mk {xs : List α} {i : Nat} (h : i < xs.length) : (Array.mk xs)[i] = xs[i] := rfl

theorem getElem_eq_getElem_toList {a : Array α} (h : i < a.size) : a[i] = a.toList[i] := by
by_cases i < a.size <;> (try simp [*]) <;> rfl
theorem getElem_eq_getElem_toList {a : Array α} (h : i < a.size) : a[i] = a.toList[i] := rfl

theorem getElem?_eq_getElem {a : Array α} {i : Nat} (h : i < a.size) : a[i]? = some a[i] :=
getElem?_pos ..
Expand Down Expand Up @@ -85,6 +84,9 @@ We prefer to pull `List.toArray` outwards.
(a.toArrayAux b).size = b.size + a.length := by
simp [size]

@[simp] theorem mem_toArray {a : α} {l : List α} : a ∈ l.toArray ↔ a ∈ l := by
simp [mem_def]

@[simp] theorem push_toArray (l : List α) (a : α) : l.toArray.push a = (l ++ [a]).toArray := by
apply ext'
simp
Expand Down Expand Up @@ -121,6 +123,30 @@ We prefer to pull `List.toArray` outwards.
rw [Array.forIn, forIn_loop_toArray]
simp

@[simp] theorem forIn'_loop_toArray [Monad m] (l : List α) (f : (a : α) → a ∈ l.toArray → β → m (ForInStep β)) (i : Nat)
(h : i ≤ l.length) (b : β) :
Array.forIn'.loop l.toArray f i h b =
forIn' (l.drop (l.length - i)) b (fun a m b => f a (by simpa using mem_of_mem_drop m) b) := by
induction i generalizing l b with
| zero =>
simp [Array.forIn'.loop]
| succ i ih =>
simp only [Array.forIn'.loop, size_toArray, getElem_toArray, ih, forIn_eq_forIn]
have t : drop (l.length - (i + 1)) l = l[l.length - i - 1] :: drop (l.length - i) l := by
simp only [Nat.sub_add_eq]
rw [List.drop_sub_one (by omega), List.getElem?_eq_getElem (by omega)]
simp only [Option.toList_some, singleton_append]
simp [t]
have t : l.length - 1 - i = l.length - i - 1 := by omega
simp only [t]
congr

@[simp] theorem forIn'_toArray [Monad m] (l : List α) (b : β) (f : (a : α) → a ∈ l.toArray → β → m (ForInStep β)) :
forIn' l.toArray b f = forIn' l b (fun a m b => f a (mem_toArray.mpr m) b) := by
change Array.forIn' _ _ _ = List.forIn' _ _ _
rw [Array.forIn', forIn'_loop_toArray]
simp [List.forIn_eq_forIn]

theorem foldrM_toArray [Monad m] (f : α → β → m β) (init : β) (l : List α) :
l.toArray.foldrM f init = l.foldrM f init := by
rw [foldrM_eq_reverse_foldlM_toList]
Expand Down Expand Up @@ -268,9 +294,6 @@ theorem anyM_stop_le_start [Monad m] (p : α → m Bool) (as : Array α) (start
(h : min stop as.size ≤ start) : anyM p as start stop = pure false := by
rw [anyM_eq_anyM_loop, anyM.loop, dif_neg (Nat.not_lt.2 h)]

theorem mem_def {a : α} {as : Array α} : a ∈ as ↔ a ∈ as.toList :=
fun | .mk h => h, Array.Mem.mk⟩

@[simp] theorem not_mem_empty (a : α) : ¬(a ∈ #[]) := by
simp [mem_def]

Expand Down Expand Up @@ -460,10 +483,6 @@ theorem lt_of_getElem {x : α} {a : Array α} {idx : Nat} {hidx : idx < a.size}
idx < a.size :=
hidx

@[simp] theorem getElem_mem {l : Array α} {i : Nat} (h : i < l.size) : l[i] ∈ l := by
erw [Array.mem_def, getElem_eq_getElem_toList]
apply List.get_mem

theorem getElem_fin_eq_getElem_toList (a : Array α) (i : Fin a.size) : a[i] = a.toList[i] := rfl

@[simp] theorem ugetElem_eq_getElem (a : Array α) {i : USize} (h : i.toNat < a.size) :
Expand Down Expand Up @@ -728,6 +747,11 @@ theorem getElem_range {n : Nat} {x : Nat} (h : x < (Array.range n).size) : (Arra
cases as
simp

@[simp] theorem forIn'_toList [Monad m] (as : Array α) (b : β) (f : (a : α) → a ∈ as.toList → β → m (ForInStep β)) :
forIn' as.toList b f = forIn' as b (fun a m b => f a (mem_toList.mpr m) b) := by
cases as
simp

/-! ### foldl / foldr -/

@[simp] theorem foldlM_loop_empty [Monad m] (f : β → α → m β) (init : β) (i j : Nat) :
Expand Down Expand Up @@ -1411,9 +1435,6 @@ namespace List
Our goal is to have `simp` "pull `List.toArray` outwards" as much as possible.
-/

@[simp] theorem mem_toArray {a : α} {l : List α} : a ∈ l.toArray ↔ a ∈ l := by
simp [mem_def]

@[simp] theorem toListRev_toArray (l : List α) : l.toArray.toListRev = l.reverse := by
simp

Expand Down
9 changes: 0 additions & 9 deletions src/Init/Data/Array/Mem.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,6 @@ import Init.Data.List.BasicAux

namespace Array

/-- `a ∈ as` is a predicate which asserts that `a` is in the array `as`. -/
-- NB: This is defined as a structure rather than a plain def so that a lemma
-- like `sizeOf_lt_of_mem` will not apply with no actual arrays around.
structure Mem (as : Array α) (a : α) : Prop where
val : a ∈ as.toList

instance : Membership α (Array α) where
mem := Mem

theorem sizeOf_lt_of_mem [SizeOf α] {as : Array α} (h : a ∈ as) : sizeOf a < sizeOf as := by
cases as with | _ as =>
exact Nat.lt_trans (List.sizeOf_lt_of_mem h.val) (by simp_arith)
Expand Down
2 changes: 2 additions & 0 deletions src/Init/Data/List/Control.lean
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ instance : ForIn m (List α) α where
instance : ForIn' m (List α) α inferInstance where
forIn' := List.forIn'

@[simp] theorem forIn'_eq_forIn' [Monad m] : @List.forIn' α β m _ = forIn' := rfl

@[simp] theorem forIn'_eq_forIn {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : List α) (init : β) (f : α → β → m (ForInStep β)) : forIn' as init (fun a _ b => f a b) = forIn as init f := by
simp [forIn', forIn, List.forIn, List.forIn']
have : ∀ cs h, List.forIn'.loop cs (fun a _ b => f a b) as init h = List.forIn.loop f as init := by
Expand Down
4 changes: 0 additions & 4 deletions src/Init/Data/List/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -492,10 +492,6 @@ theorem getElem?_of_mem {a} {l : List α} (h : a ∈ l) : ∃ n : Nat, l[n]? = s
theorem get?_of_mem {a} {l : List α} (h : a ∈ l) : ∃ n, l.get? n = some a :=
let ⟨⟨n, _⟩, e⟩ := get_of_mem h; ⟨n, e ▸ get?_eq_get _⟩

@[simp] theorem getElem_mem : ∀ {l : List α} {n} (h : n < l.length), l[n]'h ∈ l
| _ :: _, 0, _ => .head ..
| _ :: l, _+1, _ => .tail _ (getElem_mem (l := l) ..)

theorem get_mem : ∀ (l : List α) n h, get l ⟨n, h⟩ ∈ l
| _ :: _, 0, _ => .head ..
| _ :: l, _+1, _ => .tail _ (get_mem l ..)
Expand Down
62 changes: 62 additions & 0 deletions src/Init/Data/List/Monadic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,68 @@ theorem mapM_eq_reverse_foldlM_cons [Monad m] [LawfulMonad m] (f : α → m β)
(l₁ ++ l₂).forM f = (do l₁.forM f; l₂.forM f) := by
induction l₁ <;> simp [*]

/-! ### forIn' -/

@[simp] theorem forIn'_nil [Monad m] (f : (a : α) → a ∈ [] → β → m (ForInStep β)) (b : β) : forIn' [] b f = pure b :=
rfl

theorem forIn'_loop_congr [Monad m] {as bs : List α}
{f : (a' : α) → a' ∈ as → β → m (ForInStep β)}
{g : (a' : α) → a' ∈ bs → β → m (ForInStep β)}
{b : β} (ha : ∃ ys, ys ++ xs = as) (hb : ∃ ys, ys ++ xs = bs)
(h : ∀ a m m' b, f a m b = g a m' b) : forIn'.loop as f xs b ha = forIn'.loop bs g xs b hb := by
induction xs generalizing b with
| nil => simp [forIn'.loop]
| cons a xs ih =>
simp only [forIn'.loop] at *
congr 1
· rw [h]
· funext s
obtain b | b := s
· rfl
· simp
rw [ih]

@[simp] theorem forIn'_cons [Monad m] {a : α} {as : List α}
(f : (a' : α) → a' ∈ a :: as → β → m (ForInStep β)) (b : β) :
forIn' (a::as) b f = f a (mem_cons_self a as) b >>=
fun | ForInStep.done b => pure b | ForInStep.yield b => forIn' as b fun a' m b => f a' (mem_cons_of_mem a m) b := by
simp only [forIn', List.forIn', forIn'.loop]
congr 1
funext s
obtain b | b := s
· rfl
· apply forIn'_loop_congr
intros
rfl

@[congr] theorem forIn'_congr [Monad m] {as bs : List α} (w : as = bs)
{b b' : β} (hb : b = b')
{f : (a' : α) → a' ∈ as → β → m (ForInStep β)}
{g : (a' : α) → a' ∈ bs → β → m (ForInStep β)}
(h : ∀ a m b, f a (by simpa [w] using m) b = g a m b) :
forIn' as b f = forIn' bs b' g := by
induction bs generalizing as b b' with
| nil =>
subst w
simp [hb, forIn'_nil]
| cons b bs ih =>
cases as with
| nil => simp at w
| cons a as =>
simp only [cons.injEq] at w
obtain ⟨rfl, rfl⟩ := w
simp only [forIn'_cons]
congr 1
· simp [h, hb]
· funext s
obtain b | b := s
· rfl
· simp
rw [ih rfl rfl]
intro a m b
exact h a (mem_cons_of_mem _ m) b

/-! ### allM -/

theorem allM_eq_not_anyM_not [Monad m] [LawfulMonad m] (p : α → m Bool) (as : List α) :
Expand Down
24 changes: 24 additions & 0 deletions src/Init/Data/Option/List.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,28 @@ namespace Option
@[simp] theorem mem_toList {a : α} {o : Option α} : a ∈ o.toList ↔ a ∈ o := by
cases o <;> simp [eq_comm]

@[simp] theorem forIn'_none [Monad m] (b : β) (f : (a : α) → a ∈ none → β → m (ForInStep β)) :
forIn' none b f = pure b := by
rfl

@[simp] theorem forIn'_some [Monad m] (a : α) (b : β) (f : (a' : α) → a' ∈ some a → β → m (ForInStep β)) :
forIn' (some a) b f = bind (f a rfl b) (fun | .done r | .yield r => pure r) := by
rfl

@[simp] theorem forIn_none [Monad m] (b : β) (f : α → β → m (ForInStep β)) :
forIn none b f = pure b := by
rfl

@[simp] theorem forIn_some [Monad m] (a : α) (b : β) (f : α → β → m (ForInStep β)) :
forIn (some a) b f = bind (f a b) (fun | .done r | .yield r => pure r) := by
rfl

@[simp] theorem forIn'_toList [Monad m] (o : Option α) (b : β) (f : (a : α) → a ∈ o.toList → β → m (ForInStep β)) :
forIn' o.toList b f = forIn' o b fun a m b => f a (by simpa using m) b := by
cases o <;> rfl

@[simp] theorem forIn_toList [Monad m] (o : Option α) (b : β) (f : α → β → m (ForInStep β)) :
forIn o.toList b f = forIn o b f := by
cases o <;> rfl

end Option
4 changes: 4 additions & 0 deletions src/Init/GetElem.lean
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,10 @@ instance : GetElem (List α) Nat α fun as i => i < as.length where

@[deprecated (since := "2024-06-12")] abbrev cons_getElem_succ := @getElem_cons_succ

@[simp] theorem getElem_mem : ∀ {l : List α} {n} (h : n < l.length), l[n]'h ∈ l
| _ :: _, 0, _ => .head ..
| _ :: l, _+1, _ => .tail _ (getElem_mem (l := l) ..)

theorem get_drop_eq_drop (as : List α) (i : Nat) (h : i < as.length) : as[i] :: as.drop (i+1) = as.drop i :=
match as, i with
| _::_, 0 => rfl
Expand Down
8 changes: 8 additions & 0 deletions tests/lean/run/array_simp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,11 @@ attribute [local simp] Id.run in
for i in [1,2,3,4].toArray do
s := s + i
pure s) ~> 10

attribute [local simp] Id.run in
#check_simp
(Id.run do
let mut s := 0
for h : i in [1,2,3,4].toArray do
s := s + i
pure s) ~> 10
2 changes: 1 addition & 1 deletion tests/lean/run/treeNode.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def treeToList (t : TreeNode) : List String :=
return r

@[simp] theorem treeToList_eq (name : String) (children : List TreeNode) : treeToList (.mkNode name children) = name :: List.join (children.map treeToList) := by
simp [treeToList, Id.run, forIn, List.forIn]
simp only [treeToList, Id.run, Id.pure_eq, Id.bind_eq, List.forIn'_eq_forIn, forIn, List.forIn]
have : ∀ acc, (Id.run do List.forIn.loop (fun a b => ForInStep.yield (b ++ treeToList a)) children acc) = acc ++ List.join (List.map treeToList children) := by
intro acc
induction children generalizing acc with simp [List.forIn.loop, List.map, List.join, Id.run]
Expand Down

0 comments on commit 07ea626

Please sign in to comment.