diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 139d07705068..cca0c06a9f9f 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -2611,7 +2611,6 @@ theorem getLsbD_rotateLeftAux_of_geq {x : BitVec w} {r : Nat} {i : Nat} (hi : i apply getLsbD_ge omega -/-- When `r < w`, we give a formula for `(x.rotateRight r).getLsbD i`. -/ theorem getLsbD_rotateLeft_of_le {x : BitVec w} {r i : Nat} (hr: r < w) : (x.rotateLeft r).getLsbD i = cond (i < r) @@ -2638,6 +2637,52 @@ theorem getElem_rotateLeft {x : BitVec w} {r i : Nat} (h : i < w) : if h' : i < r % w then x[(w - (r % w) + i)] else x[i - (r % w)] := by simp [← BitVec.getLsbD_eq_getElem, h] +/-- If `w ≤ x < 2 * w`, then `x % w = x - w` -/ +private theorem add_mod_eq_add_sub {x w : Nat} (x_le : w ≤ x) (x_lt : x < 2 * w) : + x % w = x - w := by + rw [Nat.mod_eq_sub_mod, Nat.mod_eq_of_lt (by omega)] + omega + +theorem getMsbD_rotateLeftAux_of_le {x : BitVec w} {r : Nat} {i : Nat} (hi : i < w - r) : + (x.rotateLeftAux r).getMsbD i = x.getMsbD (r + i) := by + rw [rotateLeftAux, getMsbD_or] + simp [show i < w - r by omega, Nat.add_comm] + +theorem getMsbD_rotateLeftAux_of_geq {x : BitVec w} {r : Nat} {i : Nat} (hi : i ≥ w - r) : + (x.rotateLeftAux r).getMsbD i = (decide (i < w) && x.getMsbD (i - (w - r))) := by + simp [rotateLeftAux, getMsbD_or, show i + r ≥ w by omega, show ¬i < w - r by omega] + +/-- When `r < w`, we give a formula for `(x.rotateLeft r).getMsbD i`. -/ +theorem getMsbD_rotateLeft_of_le {n w : Nat} {x : BitVec w} (hi : r < w): + (x.rotateLeft r).getMsbD n = (decide (n < w) && x.getMsbD ((r + n) % w)) := by + rcases w with rfl | w + · simp + · rw [BitVec.rotateLeft_eq_rotateLeftAux_of_lt (by omega)] + by_cases h : n < (w + 1) - r + · simp [getMsbD_rotateLeftAux_of_le h, Nat.mod_eq_of_lt, show r + n < (w + 1) by omega, show n < w + 1 by omega] + · simp [getMsbD_rotateLeftAux_of_geq <| Nat.ge_of_not_lt h] + by_cases h₁ : n < w + 1 + · simp only [h₁, decide_True, Bool.true_and] + have h₂ : (r + n) < 2 * (w + 1) := by omega + rw [add_mod_eq_add_sub (by omega) (by omega)] + congr 1 + omega + · simp [h₁] + +theorem getMsbD_rotateLeft {r n w : Nat} {x : BitVec w} : + (x.rotateLeft r).getMsbD n = (decide (n < w) && x.getMsbD ((r + n) % w)) := by + rcases w with rfl | w + · simp + · by_cases h : r < w + · rw [getMsbD_rotateLeft_of_le (by omega)] + · rw [← rotateLeft_mod_eq_rotateLeft, getMsbD_rotateLeft_of_le (by apply Nat.mod_lt; simp)] + simp + +@[simp] +theorem msb_rotateLeft {m w : Nat} {x : BitVec w} : + (x.rotateLeft m).msb = decide (0 < w && x.getMsbD (m % w)) := by + simp [BitVec.msb, getMsbD_rotateLeft] + /-! ## Rotate Right -/ /-- @@ -2725,6 +2770,56 @@ theorem getElem_rotateRight {x : BitVec w} {r i : Nat} (h : i < w) : simp only [← BitVec.getLsbD_eq_getElem] simp [getLsbD_rotateRight, h] +theorem getMsbD_rotateRightAux_of_le {x : BitVec w} {r : Nat} {i : Nat} (hi : i < r) : + (x.rotateRightAux r).getMsbD i = x.getMsbD (i + (w - r)) := by + rw [rotateRightAux, getMsbD_or, getMsbD_ushiftRight] + simp [show i < r by omega] + +theorem getMsbD_rotateRightAux_of_geq {x : BitVec w} {r : Nat} {i : Nat} (hi : i ≥ r) : + (x.rotateRightAux r).getMsbD i = (decide (i < w) && x.getMsbD (i - r)) := by + simp [rotateRightAux, show ¬ i < r by omega, show i + (w - r) ≥ w by omega] + +/-- When `m < w`, we give a formula for `(x.rotateLeft m).getMsbD i`. -/ +@[simp] +theorem getMsbD_rotateRight_of_le {w n m : Nat} {x : BitVec w} (hr : m < w): + (x.rotateRight m).getMsbD n = (decide (n < w) && (if (n < m % w) + then x.getMsbD ((w + n - m % w) % w) else x.getMsbD (n - m % w))):= by + rcases w with rfl | w + · simp + · rw [rotateRight_eq_rotateRightAux_of_lt (by omega)] + by_cases h : n < m + · simp only [getMsbD_rotateRightAux_of_le h, show n < w + 1 by omega, decide_True, + show m % (w + 1) = m by rw [Nat.mod_eq_of_lt hr], h, ↓reduceIte, + show (w + 1 + n - m) < (w + 1) by omega, Nat.mod_eq_of_lt, Bool.true_and] + congr 1 + omega + · simp [h, getMsbD_rotateRightAux_of_geq <| Nat.ge_of_not_lt h] + by_cases h₁ : n < w + 1 + · simp [h, h₁, decide_True, Bool.true_and, Nat.mod_eq_of_lt hr] + · simp [h₁] + +@[simp] +theorem getMsbD_rotateRight {w n m : Nat} {x : BitVec w} : + (x.rotateRight m).getMsbD n = (decide (n < w) && (if (n < m % w) + then x.getMsbD ((w + n - m % w) % w) else x.getMsbD (n - m % w))):= by + rcases w with rfl | w + · simp + · by_cases h₀ : m < w + · rw [getMsbD_rotateRight_of_le (by omega)] + · rw [← rotateRight_mod_eq_rotateRight, getMsbD_rotateRight_of_le (by apply Nat.mod_lt; simp)] + simp + +@[simp] +theorem msb_rotateRight {r w: Nat} {x : BitVec w} : + (x.rotateRight r).msb = x.getMsbD ((w - r % w) % w) := by + simp only [BitVec.msb, getMsbD_rotateRight] + by_cases h₀ : 0 < w + · simp only [h₀, decide_True, Nat.add_zero, Nat.zero_le, Nat.sub_eq_zero_of_le, Bool.true_and, + ite_eq_left_iff, Nat.not_lt, Nat.le_zero_eq] + intro h₁ + simp [h₁] + · simp [show w = 0 by omega] + /- ## twoPow -/ theorem twoPow_eq (w : Nat) (i : Nat) : twoPow w i = 1#w <<< i := by