diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 283d4714376c..c7a68b5e0ad9 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -575,6 +575,10 @@ theorem zeroExtend_eq_setWidth {v : Nat} {x : BitVec w} : (x.setWidth v).toInt = Int.bmod x.toNat (2^v) := by simp [toInt_eq_toNat_bmod, toNat_setWidth, Int.emod_bmod] +@[simp] theorem toFin_setWidth (x : BitVec w) : + (x.setWidth v).toFin = Fin.ofNat' (2^v) x.toNat := by + ext; simp + theorem setWidth'_eq {x : BitVec w} (h : w ≤ v) : x.setWidth' h = x.setWidth v := by apply eq_of_toNat_eq rw [toNat_setWidth, toNat_setWidth'] @@ -651,6 +655,21 @@ theorem getElem?_setWidth (m : Nat) (x : BitVec n) (i : Nat) : getLsbD (setWidth m x) i = (decide (i < m) && getLsbD x i) := by simp [getLsbD, toNat_setWidth, Nat.testBit_mod_two_pow] +@[simp] theorem getMsbD_setWidth {m : Nat} {x : BitVec n} {i : Nat} : + getMsbD (setWidth m x) i = (decide (m - n ≤ i) && getMsbD x (i + n - m)) := by + unfold setWidth + by_cases h : n ≤ m <;> simp only [h] + · by_cases h' : (m - n ≤ i) + · simp [h', show (i - (m - n)) = i + n - m by omega] + · simp [h'] + · simp only [show m-n = 0 by omega, getMsbD, getLsbD_setWidth] + by_cases h'' : i < m + · simp [show m - 1 - i < m by omega, show i + n - m < n by omega, + show (n - 1 - (i + n - m)) = m - 1 - i by omega] + omega + · simp [h''] + omega + @[simp] theorem getMsbD_setWidth_add {x : BitVec w} (h : k ≤ i) : (x.setWidth (w + k)).getMsbD i = x.getMsbD (i - k) := by by_cases h : w = 0 @@ -1627,7 +1646,7 @@ private theorem Int.negSucc_emod (m : Nat) (n : Int) : -(m + 1) % n = Int.subNatNat (Int.natAbs n) ((m % Int.natAbs n) + 1) := rfl /-- The sign extension is the same as zero extending when `msb = false`. -/ -theorem signExtend_eq_not_setWidth_not_of_msb_false {x : BitVec w} {v : Nat} (hmsb : x.msb = false) : +theorem signExtend_eq_setWidth_of_msb_false {x : BitVec w} {v : Nat} (hmsb : x.msb = false) : x.signExtend v = x.setWidth v := by ext i by_cases hv : i < v @@ -1663,7 +1682,7 @@ theorem signExtend_eq_not_setWidth_not_of_msb_true {x : BitVec w} {v : Nat} (hms theorem getLsbD_signExtend (x : BitVec w) {v i : Nat} : (x.signExtend v).getLsbD i = (decide (i < v) && if i < w then x.getLsbD i else x.msb) := by rcases hmsb : x.msb with rfl | rfl - · rw [signExtend_eq_not_setWidth_not_of_msb_false hmsb] + · rw [signExtend_eq_setWidth_of_msb_false hmsb] by_cases (i < v) <;> by_cases (i < w) <;> simp_all <;> omega · rw [signExtend_eq_not_setWidth_not_of_msb_true hmsb] by_cases (i < v) <;> by_cases (i < w) <;> simp_all <;> omega @@ -1671,31 +1690,46 @@ theorem getLsbD_signExtend (x : BitVec w) {v i : Nat} : protected theorem Nat.sub_sub_comm {n m k : Nat} : n - m - k = n - k - m := sorry theorem getMsbD_signExtend (x : BitVec w) {v i : Nat} : - (x.signExtend v).getMsbD i = (decide (i < v) && if w+i-v < w then x.getMsbD (w+i-v) else x.msb) := by + (x.signExtend v).getMsbD i = (decide (i < v) && (if i > v-w then x.getMsbD (i-(v-w)) else x.msb)) := by rcases hmsb : x.msb with rfl | rfl - · rw [signExtend_eq_not_setWidth_not_of_msb_false hmsb] - simp_all [getMsbD] - by_cases h' : (i < v) <;> by_cases h'': ((w+i-v < w)) <;> simp [getMsbD, h', h''] - have h''': ((v - 1 - i < v)) := by omega - simp [h'''] - by_cases h5 : v ≤ w - · rw [show (w - 1 - (w + i - v)) = (v - 1 - i) by ( - rw [Nat.sub_add_comm] - rw [← Nat.sub_sub] - rw [Nat.sub_sub_comm (m := 1)] - rw [Nat.sub_sub_eq_min] - rw [Nat.min_eq_right] - omega; omega)] - · + · simp [signExtend_eq_setWidth_of_msb_false hmsb] + simp [getMsbD_setWidth'] + rw [signExtend_eq_not_setWidth_not_of_msb_false hmsb] + rw [getMsbD] + rw [getMsbD] + simp + by_cases h : i < v + · + simp [h] + simp [show v - 1 - i < v by omega] + by_cases h' : v - w < i + · simp [h'] + simp [show (i - (v - w) < w) by omega] + rw [show (w - 1 - (i - (v-w))) = (v - 1 - i) by omega] + · simp [h'] + omega + · simp [h] · rw [signExtend_eq_not_setWidth_not_of_msb_true hmsb] - by_cases h' : (i < v) <;> by_cases h'': ((w+i-v < w)) <;> simp [getMsbD, h', h''] - have h''': ((v - 1 - i < w)) := by omega - simp[h'''] + rw [getMsbD] + rw [getMsbD] + simp + by_cases h : i < v + · + simp [h] + simp [show v - 1 - i < v by omega] + by_cases h' : v - w < i + · simp [h'] + simp [show (i - (v - w) < w) by omega] + simp [show (v - 1 - i < w) by omega] + rw [show (w - 1 - (i - (v-w))) = (v - 1 - i) by omega] + · simp [h'] + rw [getlsbD_of] - <;> rw [show (w - 1 - (w - v + i)) = (v - 1 - i) by omega] + omega + · simp [h] theorem getElem_signExtend {x : BitVec w} {v i : Nat} (h : i < v) : (x.signExtend v)[i] = if i < w then x.getLsbD i else x.msb := by