diff --git a/src/Init/Data/BitVec/Basic.lean b/src/Init/Data/BitVec/Basic.lean index 9566fe2ec8bc..5ee468579a2d 100644 --- a/src/Init/Data/BitVec/Basic.lean +++ b/src/Init/Data/BitVec/Basic.lean @@ -534,6 +534,8 @@ def sshiftRight (a : BitVec n) (s : Nat) : BitVec n := .ofInt n (a.toInt >>> s) instance {n} : HShiftLeft (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x <<< y.toNat⟩ instance {n} : HShiftRight (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x >>> y.toNat⟩ +def sshiftRight' (a : BitVec n) (s : BitVec m) : BitVec n := a.sshiftRight s.toNat + /-- Auxiliary function for `rotateLeft`, which does not take into account the case where the rotation amount is greater than the bitvector width. -/ def rotateLeftAux (x : BitVec w) (n : Nat) : BitVec w := diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean index ffc1b7d79f81..385f01c2990c 100644 --- a/src/Init/Data/BitVec/Bitblast.lean +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -403,12 +403,8 @@ theorem shiftLeftRec_eq {x : BitVec w₁} {y : BitVec w₂} {n : Nat} : induction n generalizing x y case zero => ext i - simp only [shiftLeftRec_zero, twoPow_zero, Nat.reduceAdd, truncate_one] - suffices (y &&& 1#w₂) = zeroExtend w₂ (ofBool (y.getLsb 0)) by simp [this] - ext i - by_cases h : (↑i : Nat) = 0 - · simp [h, Bool.and_comm] - · simp [h]; omega + simp only [shiftLeftRec_zero, twoPow_zero, Nat.reduceAdd, truncate_one, + and_one_eq_zeroExtend_ofBool_getLsb] case succ n ih => simp only [shiftLeftRec_succ, and_twoPow] rw [ih] @@ -431,4 +427,103 @@ theorem shiftLeft_eq_shiftLeftRec (x : BitVec w₁) (y : BitVec w₂) : · simp [of_length_zero] · simp [shiftLeftRec_eq] +/- ### Logical shift right (ushiftRight) recurrence for bitblasting -/ + +def ushiftRight_rec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ := + let shiftAmt := (y &&& (twoPow w₂ n)) + match n with + | 0 => x >>> shiftAmt + | n + 1 => (ushiftRight_rec x y n) >>> shiftAmt + +@[simp] +theorem ushiftRight_rec_zero (x : BitVec w₁) (y : BitVec w₂) : + ushiftRight_rec x y 0 = x >>> (y &&& twoPow w₂ 0) := by + simp [ushiftRight_rec] + +@[simp] +theorem ushiftRight_rec_succ (x : BitVec w₁) (y : BitVec w₂) : + ushiftRight_rec x y (n + 1) = + (ushiftRight_rec x y n) >>> (y &&& twoPow w₂ (n + 1)) := by + simp [ushiftRight_rec] + +theorem ushiftRight'_ushiftRight' {x y z : BitVec w} : + x >>> y >>> z = x >>> (y.toNat + z.toNat) := by + simp [shiftRight_add] + +theorem ushiftRight_or_eq_ushiftRight_ushiftRight_of_and_eq_zero {x : BitVec w₁} {y z : BitVec w₂} + (h : y &&& z = 0#w₂) : + x >>> (y ||| z) = x >>> y >>> z := by + simp [← add_eq_or_of_and_eq_zero _ _ h, toNat_add_of_and_eq_zero h, shiftRight_add] + +theorem getLsb_ushiftRight' (x : BitVec w₁) (y : BitVec w₂) (i : Nat) : + (x >>> y).getLsb i = x.getLsb (y.toNat + i) := by + simp [getLsb_ushiftRight] + +theorem ushiftRight_rec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : + ushiftRight_rec x y n = x >>> (y.truncate (n + 1)).zeroExtend w₂ := by + induction n generalizing x y + case zero => + ext i + simp only [ushiftRight_rec_zero, twoPow_zero, Nat.reduceAdd, + and_one_eq_zeroExtend_ofBool_getLsb, truncate_one] + case succ n ih => + simp only [ushiftRight_rec_succ, and_twoPow] + rw [ih] + by_cases h : y.getLsb (n + 1) <;> simp only [h, ↓reduceIte] + · rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true h] + rw [ushiftRight_or_eq_ushiftRight_ushiftRight_of_and_eq_zero] + simp + · simp [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1), h] + +theorem shiftRight_eq_shiftRight_rec (x : BitVec w₁) (y : BitVec w₂) : + x >>> y = ushiftRight_rec x y (w₂ - 1) := by + rcases w₂ with rfl | w₂ + · simp [of_length_zero] + · simp [ushiftRight_rec_eq] + +/- ### Arithmetic shift right (sshiftRight) recurrence -/ + +def sshiftRightRec (x : BitVec w) (y : BitVec w₂) (n : Nat) : BitVec w := + let shiftAmt := (y &&& (twoPow w₂ n)) + match n with + | 0 => x.sshiftRight' shiftAmt + | n + 1 => (sshiftRightRec x y n).sshiftRight' shiftAmt + +@[simp] +theorem sshiftRightRec_zero_eq (x : BitVec w) (y : BitVec w₂) : + sshiftRightRec x y 0 = x.sshiftRight' (y &&& 1#w₂) := by + simp only [sshiftRightRec, twoPow_zero] + +@[simp] +theorem sshiftRightRec_succ_eq (x : BitVec w) (y : BitVec w₂) (n : Nat) : + sshiftRightRec x y (n + 1) = (sshiftRightRec x y n).sshiftRight' (y &&& twoPow w₂ (n + 1)) := by + simp [sshiftRightRec] + +theorem sshiftRight_or_eq_sshiftRight_sshiftRight_of_and_eq_zero {x : BitVec w} {y z : BitVec w₂} + (h : y &&& z = 0#w₂) : + x.sshiftRight' (y ||| z) = (x.sshiftRight' y).sshiftRight' z := by + simp [sshiftRight', ← add_eq_or_of_and_eq_zero _ _ h, + toNat_add_of_and_eq_zero h, sshiftRight'_add] + +theorem sshiftRightRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : + sshiftRightRec x y n = x.sshiftRight' ((y.truncate (n + 1)).zeroExtend w₂) := by + induction n generalizing x y + case zero => + ext i + simp [ushiftRight_rec_zero, twoPow_zero, Nat.reduceAdd, and_one_eq_zeroExtend_ofBool_getLsb, + truncate_one] + case succ n ih => + simp + by_cases h : y.getLsb (n + 1) <;> simp [h] + · simp [ih, zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true h, + sshiftRight_or_eq_sshiftRight_sshiftRight_of_and_eq_zero] + · simp [ih, + zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1), h] + +theorem sshiftRight_eq_sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) : + (x.sshiftRight' y).getLsb i = (sshiftRightRec x y (w₂ - 1)).getLsb i := by + rcases w₂ with rfl | w₂ + · simp [of_length_zero] + · simp [sshiftRightRec_eq] + end BitVec diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index fc1a22bddc99..db6868a317a0 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -731,6 +731,19 @@ theorem getLsb_shiftLeft' {x : BitVec w₁} {y : BitVec w₂} {i : Nat} : getLsb (x >>> i) j = getLsb x (i+j) := by unfold getLsb ; simp +@[simp] +theorem ushiftRight_zero_eq (x : BitVec w) : x >>> 0 = x := by + simp [bv_toNat] + +/-! ### ushiftRight reductions from BitVec to Nat -/ + +@[simp] +theorem ushiftRight_eq' (x : BitVec w) (y : BitVec w₂) : + x >>> y = x >>> y.toNat := by rfl + +-- @[simp] +-- theorem ushiftRight'_zero_eq (x : BitVec w) : x >>> (0#w₂) = x := by simp + /-! ### sshiftRight -/ theorem sshiftRight_eq {x : BitVec n} {i : Nat} : @@ -795,6 +808,44 @@ theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) : Nat.not_lt, decide_eq_true_eq] omega +/-- The msb after arithmetic shifting right equals the original msb. -/ +theorem sshiftRight_msb_eq_msb {n : Nat} {x : BitVec w} : + (x.sshiftRight n).msb = x.msb := by + rw [msb_eq_getLsb_last, getLsb_sshiftRight, msb_eq_getLsb_last] + by_cases hw₀ : w = 0 + · simp [hw₀] + · simp only [show ¬(w ≤ w - 1) by omega, decide_False, Bool.not_false, Bool.true_and, + ite_eq_right_iff] + intros h + simp [show n = 0 by omega] + +-- TODO: convert this into sshiftRight_add +theorem sshiftRight_sshiftRight {x : BitVec w} {m n : Nat} : + (x.sshiftRight m).sshiftRight n = x.sshiftRight (m + n) := by + ext i + simp only [getLsb_sshiftRight] + simp only [Nat.add_assoc] + by_cases h₁ : w ≤ (i : Nat) + · simp [h₁] + · simp only [h₁, decide_False, Bool.not_false, Bool.true_and] + by_cases h₂ : n + ↑i < w + · simp [h₂] + · simp only [h₂, ↓reduceIte] + by_cases h₃ : m + (n + ↑i) < w + · simp [h₃] + omega + · simp [h₃] + apply sshiftRight_msb_eq_msb + + +/-! ### shiftRight reductions from BitVec to Nat -/ + +@[simp] +theorem sshiftRight'_zero (x : BitVec w) : + x.sshiftRight' (0#w₂) = x := by + ext i + simp [sshiftRight', getLsb_sshiftRight] + /-! ### signExtend -/ /-- Equation theorem for `Int.sub` when both arguments are `Int.ofNat` -/ @@ -929,6 +980,11 @@ theorem shiftRight_add {w : Nat} (x : BitVec w) (n m : Nat) : ext i simp [Nat.add_assoc n m i] +theorem sshiftRight'_add {x : BitVec w₁} {y : BitVec w₂} {z : BitVec w₃} : + x.sshiftRight (y.toNat + z.toNat) = (x.sshiftRight' y).sshiftRight' z := by + simp [sshiftRight', shiftRight_add, sshiftRight_sshiftRight] + + @[deprecated shiftRight_add (since := "2024-06-02")] theorem shiftRight_shiftRight {w : Nat} (x : BitVec w) (n m : Nat) : (x >>> n) >>> m = x >>> (n + m) := by @@ -1549,4 +1605,12 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true simp [hx] · by_cases hik' : k < i + 1 <;> simp [hik, hik'] <;> omega +/-- Bitwise `and` of `(x : BitVec w`) with `1#w` equals zero extending the `lsb` to `w`. -/ +theorem and_one_eq_zeroExtend_ofBool_getLsb {x : BitVec w} : + (x &&& 1#w) = zeroExtend w (ofBool (x.getLsb 0)) := by + ext i + simp only [getLsb_and, getLsb_one, getLsb_zeroExtend, Fin.is_lt, decide_True, getLsb_ofBool, + Bool.true_and] + by_cases h : (0 = (i : Nat)) <;> simp [h] <;> omega + end BitVec