Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add BitVec.(getMsbD, msb)_(rotateLeft, RotateRight) #34

Closed
wants to merge 10 commits into from
97 changes: 96 additions & 1 deletion src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2611,7 +2611,6 @@ theorem getLsbD_rotateLeftAux_of_geq {x : BitVec w} {r : Nat} {i : Nat} (hi : i
apply getLsbD_ge
omega

/-- When `r < w`, we give a formula for `(x.rotateRight r).getLsbD i`. -/
theorem getLsbD_rotateLeft_of_le {x : BitVec w} {r i : Nat} (hr: r < w) :
(x.rotateLeft r).getLsbD i =
cond (i < r)
Expand All @@ -2638,6 +2637,52 @@ theorem getElem_rotateLeft {x : BitVec w} {r i : Nat} (h : i < w) :
if h' : i < r % w then x[(w - (r % w) + i)] else x[i - (r % w)] := by
simp [← BitVec.getLsbD_eq_getElem, h]

/-- If `w ≤ x < 2 * w`, then `x % w = x - w` -/
private theorem add_mod_eq_add_sub {x w : Nat} (x_le : w ≤ x) (x_lt : x < 2 * w) :
x % w = x - w := by
rw [Nat.mod_eq_sub_mod, Nat.mod_eq_of_lt (by omega)]
omega

theorem getMsbD_rotateLeftAux_of_le {x : BitVec w} {r : Nat} {i : Nat} (hi : i < w - r) :
(x.rotateLeftAux r).getMsbD i = x.getMsbD (r + i) := by
rw [rotateLeftAux, getMsbD_or]
simp [show i < w - r by omega, Nat.add_comm]

theorem getMsbD_rotateLeftAux_of_geq {x : BitVec w} {r : Nat} {i : Nat} (hi : i ≥ w - r) :
(x.rotateLeftAux r).getMsbD i = (decide (i < w) && x.getMsbD (i - (w - r))) := by
simp [rotateLeftAux, getMsbD_or, show i + r ≥ w by omega, show ¬i < w - r by omega]

/-- When `r < w`, we give a formula for `(x.rotateLeft r).getMsbD i`. -/
theorem getMsbD_rotateLeft_of_le {n w : Nat} {x : BitVec w} (hi : r < w):
(x.rotateLeft r).getMsbD n = (decide (n < w) && x.getMsbD ((r + n) % w)) := by
rcases w with rfl | w
· simp
· rw [BitVec.rotateLeft_eq_rotateLeftAux_of_lt (by omega)]
by_cases h : n < (w + 1) - r
· simp [getMsbD_rotateLeftAux_of_le h, Nat.mod_eq_of_lt, show r + n < (w + 1) by omega, show n < w + 1 by omega]
· simp [getMsbD_rotateLeftAux_of_geq <| Nat.ge_of_not_lt h]
by_cases h₁ : n < w + 1
· simp only [h₁, decide_True, Bool.true_and]
have h₂ : (r + n) < 2 * (w + 1) := by omega
rw [add_mod_eq_add_sub (by omega) (by omega)]
congr 1
omega
· simp [h₁]

theorem getMsbD_rotateLeft {r n w : Nat} {x : BitVec w} :
(x.rotateLeft r).getMsbD n = (decide (n < w) && x.getMsbD ((r + n) % w)) := by
rcases w with rfl | w
luisacicolini marked this conversation as resolved.
Show resolved Hide resolved
· simp
· by_cases h : r < w
· rw [getMsbD_rotateLeft_of_le (by omega)]
· rw [← rotateLeft_mod_eq_rotateLeft, getMsbD_rotateLeft_of_le (by apply Nat.mod_lt; simp)]
simp

@[simp]
theorem msb_rotateLeft {m w : Nat} {x : BitVec w} :
(x.rotateLeft m).msb = decide (0 < w && x.getMsbD (m % w)) := by
simp [BitVec.msb, getMsbD_rotateLeft]

/-! ## Rotate Right -/

/--
Expand Down Expand Up @@ -2725,6 +2770,56 @@ theorem getElem_rotateRight {x : BitVec w} {r i : Nat} (h : i < w) :
simp only [← BitVec.getLsbD_eq_getElem]
simp [getLsbD_rotateRight, h]

theorem getMsbD_rotateRightAux_of_le {x : BitVec w} {r : Nat} {i : Nat} (hi : i < r) :
(x.rotateRightAux r).getMsbD i = x.getMsbD (i + (w - r)) := by
rw [rotateRightAux, getMsbD_or, getMsbD_ushiftRight]
simp [show i < r by omega]

theorem getMsbD_rotateRightAux_of_geq {x : BitVec w} {r : Nat} {i : Nat} (hi : i ≥ r) :
(x.rotateRightAux r).getMsbD i = (decide (i < w) && x.getMsbD (i - r)) := by
simp [rotateRightAux, show ¬ i < r by omega, show i + (w - r) ≥ w by omega]

/-- When `m < w`, we give a formula for `(x.rotateLeft m).getMsbD i`. -/
@[simp]
theorem getMsbD_rotateRight_of_le {w n m : Nat} {x : BitVec w} (hr : m < w):
(x.rotateRight m).getMsbD n = (decide (n < w) && (if (n < m % w)
then x.getMsbD ((w + n - m % w) % w) else x.getMsbD (n - m % w))):= by
rcases w with rfl | w
· simp
· rw [rotateRight_eq_rotateRightAux_of_lt (by omega)]
by_cases h : n < m
· simp only [getMsbD_rotateRightAux_of_le h, show n < w + 1 by omega, decide_True,
show m % (w + 1) = m by rw [Nat.mod_eq_of_lt hr], h, ↓reduceIte,
show (w + 1 + n - m) < (w + 1) by omega, Nat.mod_eq_of_lt, Bool.true_and]
congr 1
omega
· simp [h, getMsbD_rotateRightAux_of_geq <| Nat.ge_of_not_lt h]
by_cases h₁ : n < w + 1
· simp [h, h₁, decide_True, Bool.true_and, Nat.mod_eq_of_lt hr]
· simp [h₁]

@[simp]
theorem getMsbD_rotateRight {w n m : Nat} {x : BitVec w} :
(x.rotateRight m).getMsbD n = (decide (n < w) && (if (n < m % w)
then x.getMsbD ((w + n - m % w) % w) else x.getMsbD (n - m % w))):= by
rcases w with rfl | w
· simp
· by_cases h₀ : m < w
· rw [getMsbD_rotateRight_of_le (by omega)]
· rw [← rotateRight_mod_eq_rotateRight, getMsbD_rotateRight_of_le (by apply Nat.mod_lt; simp)]
simp

@[simp]
theorem msb_rotateRight {r w: Nat} {x : BitVec w} :
(x.rotateRight r).msb = x.getMsbD ((w - r % w) % w) := by
simp only [BitVec.msb, getMsbD_rotateRight]
by_cases h₀ : 0 < w
· simp only [h₀, decide_True, Nat.add_zero, Nat.zero_le, Nat.sub_eq_zero_of_le, Bool.true_and,
ite_eq_left_iff, Nat.not_lt, Nat.le_zero_eq]
intro h₁
simp [h₁]
· simp [show w = 0 by omega]

/- ## twoPow -/

theorem twoPow_eq (w : Nat) (i : Nat) : twoPow w i = 1#w <<< i := by
Expand Down
Loading