diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 21523d5f8744..7e5de36317bf 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -498,6 +498,9 @@ theorem toInt_ofNat {n : Nat} (x : Nat) : @[simp] theorem ofInt_ofNat (w n : Nat) : BitVec.ofInt w (no_index (OfNat.ofNat n)) = BitVec.ofNat w (OfNat.ofNat n) := rfl +@[simp] theorem ofInt_toInt (x : BitVec w) : BitVec.ofInt w (x.toInt) = x := by + by_cases h : 2 * x.toNat < 2^w <;> ext <;> simp [getLsbD, h, BitVec.toInt] + theorem toInt_neg_iff {w : Nat} {x : BitVec w} : BitVec.toInt x < 0 ↔ 2 ^ w ≤ 2 * x.toNat := by simp [toInt_eq_toNat_cond]; omega @@ -520,6 +523,9 @@ theorem eq_zero_or_eq_one (a : BitVec 1) : a = 0#1 ∨ a = 1#1 := by theorem toInt_zero {w : Nat} : (0#w).toInt = 0 := by simp [BitVec.toInt, show 0 < 2^w by exact Nat.two_pow_pos w] +@[simp] theorem toInt_cast (h : w = v) (x : BitVec w) : (cast h x).toInt = x.toInt := by + simp [toInt_eq_toNat_bmod, h] + /-! ### slt -/ /-- @@ -569,6 +575,10 @@ theorem zeroExtend_eq_setWidth {v : Nat} {x : BitVec w} : (x.setWidth v).toInt = Int.bmod x.toNat (2^v) := by simp [toInt_eq_toNat_bmod, toNat_setWidth, Int.emod_bmod] +@[simp] theorem toFin_setWidth (x : BitVec w) : + (x.setWidth v).toFin = Fin.ofNat' (2^v) x.toNat := by + ext; simp + theorem setWidth'_eq {x : BitVec w} (h : w ≤ v) : x.setWidth' h = x.setWidth v := by apply eq_of_toNat_eq rw [toNat_setWidth, toNat_setWidth'] @@ -645,6 +655,21 @@ theorem getElem?_setWidth (m : Nat) (x : BitVec n) (i : Nat) : getLsbD (setWidth m x) i = (decide (i < m) && getLsbD x i) := by simp [getLsbD, toNat_setWidth, Nat.testBit_mod_two_pow] +theorem getMsbD_setWidth {m : Nat} {x : BitVec n} {i : Nat} : + getMsbD (setWidth m x) i = (decide (m - n ≤ i) && getMsbD x (i + n - m)) := by + unfold setWidth + by_cases h : n ≤ m <;> simp only [h] + · by_cases h' : (m - n ≤ i) + · simp [h', show (i - (m - n)) = i + n - m by omega] + · simp [h'] + · simp only [show m-n = 0 by omega, getMsbD, getLsbD_setWidth] + by_cases h'' : i < m + · simp [show m - 1 - i < m by omega, show i + n - m < n by omega, + show (n - 1 - (i + n - m)) = m - 1 - i by omega] + omega + · simp [h''] + omega + @[simp] theorem getMsbD_setWidth_add {x : BitVec w} (h : k ≤ i) : (x.setWidth (w + k)).getMsbD i = x.getMsbD (i - k) := by by_cases h : w = 0 @@ -1124,6 +1149,11 @@ theorem not_eq_comm {x y : BitVec w} : ~~~ x = y ↔ x = ~~~ y := by BitVec.toNat (x <<< n) = BitVec.toNat x <<< n % 2^v := BitVec.toNat_ofNat _ _ +@[simp] theorem toInt_shiftLeft {n : Nat} {x : BitVec v} : + BitVec.toInt (x <<< n) = Int.bmod (BitVec.toInt x * 2^n) (2^v) := by + simp [toInt_eq_toNat_bmod, Nat.shiftLeft_eq] + norm_cast + @[simp] theorem toFin_shiftLeft {n : Nat} (x : BitVec w) : BitVec.toFin (x <<< n) = Fin.ofNat' (2^w) (x.toNat <<< n) := rfl @@ -1223,6 +1253,7 @@ theorem shiftLeftZeroExtend_eq {x : BitVec w} : (shiftLeftZeroExtend x i).msb = x.msb := by simp [shiftLeftZeroExtend_eq, BitVec.msb] + theorem shiftLeft_add {w : Nat} (x : BitVec w) (n m : Nat) : x <<< (n + m) = (x <<< n) <<< m := by ext i @@ -1616,7 +1647,7 @@ private theorem Int.negSucc_emod (m : Nat) (n : Int) : -(m + 1) % n = Int.subNatNat (Int.natAbs n) ((m % Int.natAbs n) + 1) := rfl /-- The sign extension is the same as zero extending when `msb = false`. -/ -theorem signExtend_eq_not_setWidth_not_of_msb_false {x : BitVec w} {v : Nat} (hmsb : x.msb = false) : +theorem signExtend_eq_setWidth_of_msb_false {x : BitVec w} {v : Nat} (hmsb : x.msb = false) : x.signExtend v = x.setWidth v := by ext i by_cases hv : i < v @@ -1652,16 +1683,33 @@ theorem signExtend_eq_not_setWidth_not_of_msb_true {x : BitVec w} {v : Nat} (hms theorem getLsbD_signExtend (x : BitVec w) {v i : Nat} : (x.signExtend v).getLsbD i = (decide (i < v) && if i < w then x.getLsbD i else x.msb) := by rcases hmsb : x.msb with rfl | rfl - · rw [signExtend_eq_not_setWidth_not_of_msb_false hmsb] + · rw [signExtend_eq_setWidth_of_msb_false hmsb] by_cases (i < v) <;> by_cases (i < w) <;> simp_all <;> omega · rw [signExtend_eq_not_setWidth_not_of_msb_true hmsb] by_cases (i < v) <;> by_cases (i < w) <;> simp_all <;> omega +theorem getMsbD_signExtend {x : BitVec w} {v i : Nat} : + (x.signExtend v).getMsbD i = + (decide (i < v) && if v - w ≤ i then x.getMsbD (i + w - v) else x.msb) := by + rcases hmsb : x.msb with rfl | rfl + · simp [signExtend_eq_setWidth_of_msb_false hmsb, getMsbD_setWidth] + by_cases h : (v - w ≤ i) <;> simp [h, getMsbD] <;> omega + · simp only [signExtend_eq_not_setWidth_not_of_msb_true hmsb, getMsbD_not, + getMsbD_setWidth, Bool.not_and, Bool.not_not, Bool.if_true_right] + by_cases h : i < v <;> by_cases h' : v - w ≤ i <;> simp [h, h'] <;> omega + theorem getElem_signExtend {x : BitVec w} {v i : Nat} (h : i < v) : (x.signExtend v)[i] = if i < w then x.getLsbD i else x.msb := by rw [←getLsbD_eq_getElem, getLsbD_signExtend] simp [h] +theorem msb_SignExtend {x : BitVec w} : + (x.signExtend v).msb = (decide (0 < v) && if w ≥ v then x.getMsbD (w - v) else x.msb) := by + simp [BitVec.msb, getMsbD_signExtend] + by_cases h : w ≥ v + · simp [h, show v - w = 0 by omega] + · simp [h, show ¬ (v - w = 0) by omega] + /-- Sign extending to a width smaller than the starting width is a truncation. -/ theorem signExtend_eq_setWidth_of_lt (x : BitVec w) {v : Nat} (hv : v ≤ w): x.signExtend v = x.setWidth v := by @@ -1759,6 +1807,23 @@ theorem append_def (x : BitVec v) (y : BitVec w) : (x ++ y).toNat = x.toNat <<< n ||| y.toNat := rfl +/-- Helper theorem to show that the expression in `(x ++ y).toFin` is inbounds. -/ +theorem toNat_append_lt {m n : Nat} (x : BitVec m) (y : BitVec n) : + x.toNat <<< n ||| y.toNat < 2 ^ (m + n) := by + have hnLe : 2^n ≤ 2 ^(m + n) := by + rw [Nat.pow_add] + exact Nat.le_mul_of_pos_left (2 ^ n) (Nat.two_pow_pos m) + apply Nat.or_lt_two_pow + · have := Nat.two_pow_pos n + rw [Nat.shiftLeft_eq, Nat.pow_add, Nat.mul_lt_mul_right] + <;> omega + · omega + +@[simp] theorem toFin_append {x : BitVec m} {y : BitVec n} : + (x ++ y).toFin = @Fin.mk (2^(m+n)) (x.toNat <<< n ||| y.toNat) (toNat_append_lt x y) := by + ext + simp + theorem getLsbD_append {x : BitVec n} {y : BitVec m} : getLsbD (x ++ y) i = bif i < m then getLsbD y i else getLsbD x (i - m) := by simp only [append_def, getLsbD_or, getLsbD_shiftLeftZeroExtend, getLsbD_setWidth'] @@ -1806,6 +1871,173 @@ theorem msb_append {x : BitVec w} {y : BitVec v} : ext simp only [getLsbD_append, getLsbD_zero, Bool.cond_self] +theorem append_zero {n m : Nat} {x : BitVec n} : + x ++ 0#m = x.signExtend (n + m) <<< m := by + induction m + case zero => + simp [signExtend] + case succ i ih => + simp [bv_toNat] + sorry + +def lhs (x : BitVec n) (y : BitVec m) : Int := (x++y).toInt +def rhs (x : BitVec n) (y : BitVec m) : Int := if n == 0 then y.toInt else (x.toInt * (2^m)) + y.toNat + +def eq (x: BitVec n) (y: BitVec m) : Bool := (lhs x y) = (rhs x y) + + +#eval (-5#10 ++ 3#2).toInt + +def test : Bool := Id.run do + for i in [0, 1, 2, 3, 4, 5, 6, 7, 8] do + for j in [0, 1, 2, 3, 4, 5, 6, 7, 8] do + for n in [0, 1, 2, 3, 4] do + for m in [0, 1, 2, 3, 4] do + let x := BitVec.ofNat n i + let y := BitVec.ofNat m j + if (!eq x y) then + return false + return true + +private theorem Nat.lt_mul_of_le_of_lt_of_lt {a b c : Nat} (hab : a ≤ b) (ha : 0 < a) (hc : 1 < c) : + a < b * c := by + have : a * 1 < b * c := Nat.mul_lt_mul_of_le_of_lt' (by omega) (by simp [hc]) (by omega) + simp at this + simp [this] + +private theorem Nat.two_pow_lt_two_pow_add {n m : Nat} (h : m ≠ 0) : + 2 ^ n < 2 ^ (n + m) := by + apply Nat.pow_lt_pow_of_lt (by omega) (by omega) + +@[simp] theorem signExtend_shiftLeft_msb {n m : Nat} {x : BitVec n} : + (signExtend (n + m) x <<< m).msb = x.msb := by + induction m + case zero => + simp [signExtend] + case succ i ih => + rw [← ih] + rw [msb_setWidth] + + unfold BitVec.msb getMsbD + simp + by_cases h : (0 < n + i) + · + rw [← Nat.add_assoc] + simp [h] + have h' : (0 < n + i + 1) := by omega + have hh : (n + i - (i + 1)) = (n + i - i - 1) := by + omega + rw [hh] + simp + have hhh : (n + i - 1 - i) = (n + i - i - 1) := by omega + rw [hhh] + simp + rw [getLsbD_signExtend] + + + + + + + simp [BitVec.msb, getMsbD] + + by_cases h : 0 < n + (i + 1) + · simp [h] + + sorry + · simp [h] + sorry + +@[simp] theorem signExtend_toNat_shift_mod : + ((signExtend (n + m) x).toNat <<< m) % ↑(2 ^ (n + m)) = (signExtend (n + m) x).toNat <<< m := + sorry + +@[simp] theorem toInt_append_zero {n m : Nat} {x : BitVec n} : + (x ++ 0#m).toInt = x.toInt * (2 ^ m) := by + by_cases m0 : m = 0 + · subst m0 + simp + · simp only [ofNat_eq_ofNat, append_zero, toInt_eq_msb_cond] + by_cases h1 : (signExtend (n + m) x <<< m).msb + · by_cases h2: x.msb + · norm_cast + simp [h1, h2] + norm_cast + rw [Int.sub_mul, Nat.pow_add] + norm_cast + simp + rw [Nat.shiftLeft_eq] + norm_cast + have aa := @Nat.pow_pos 2 m (by omega) + norm_cast + have bb := @Nat.mul_right_cancel_iff (2^m) ((signExtend (n + m) x).toNat) + apply bb + rfl + rw [Nat.mul_right_cancel (m := 2 ^ m)] + simp [aa] + rw [Nat.mod_eq_of_lt (a := x.toNat) (by omega)] + norm_cast + simp [h3] + simp_all + rw [Nat.shiftLeft_eq] + · simp only [signExtend_shiftLeft_of_lt] at h1 + contradiction + · by_cases h2: x.msb + · simp [signExtend_shiftLeft_of_lt, h2] at h1 + · sorry + +@[simp] theorem toInt_append {x : BitVec n} {y : BitVec m} : + (x ++ y).toInt = if n == 0 then y.toInt else x.toInt * (2 ^ m) + y.toNat := by + by_cases n0 : n = 0 + · subst n0 + simp [BitVec.eq_nil x] + · by_cases m0 : m = 0 + · subst m0 + simp [BitVec.eq_nil y, n0] + · simp [m0] + by_cases y0 : y = 0 + · simp [toInt_append_zero, y0, n0] + rw [toInt_eq_toNat_cond] + rw [toInt_eq_toNat_cond] + split + · + split + <;> norm_cast + <;> simp + <;> rw [Nat.mod_eq_of_lt (a := x.toNat) (by omega)] + <;> norm_cast + <;> simp [h3] + <;> simp_all + · rw [Nat.shiftLeft_eq] + · rename_i aa bb + rw [Nat.shiftLeft_eq] at aa + rw [Nat.pow_add] at aa + rw [← Nat.mul_assoc] at aa + + sorry + · + split + <;> norm_cast + <;> simp + <;> rw [Nat.mod_eq_of_lt (a := x.toNat) (by omega)] + <;> norm_cast + <;> simp [h3] + <;> simp_all + · rename_i aa bb + rw [Nat.shiftLeft_eq] at aa + rw [Nat.pow_add] at aa + rw [← Nat.mul_assoc] at aa + + + + + simp_all + + + sorry + · simp [Nat.shiftLeft_eq, Int.sub_mul, Nat.pow_add] + · sorry + @[simp] theorem cast_append_right (h : w + v = w + v') (x : BitVec w) (y : BitVec v) : cast h (x ++ y) = x ++ cast (by omega) y := by ext