Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: shiftRight bitblasting theorems #11

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Init/Data/BitVec/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,8 @@ def sshiftRight (a : BitVec n) (s : Nat) : BitVec n := .ofInt n (a.toInt >>> s)
instance {n} : HShiftLeft (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x <<< y.toNat⟩
instance {n} : HShiftRight (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x >>> y.toNat⟩

def sshiftRight' (a : BitVec n) (s : BitVec m) : BitVec n := a.sshiftRight s.toNat

/-- Auxiliary function for `rotateLeft`, which does not take into account the case where
the rotation amount is greater than the bitvector width. -/
def rotateLeftAux (x : BitVec w) (n : Nat) : BitVec w :=
Expand Down
106 changes: 100 additions & 6 deletions src/Init/Data/BitVec/Bitblast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -403,12 +403,8 @@ theorem shiftLeftRec_eq {x : BitVec w₁} {y : BitVec w₂} {n : Nat} :
induction n generalizing x y
case zero =>
ext i
simp only [shiftLeftRec_zero, twoPow_zero, Nat.reduceAdd, truncate_one]
suffices (y &&& 1#w₂) = zeroExtend w₂ (ofBool (y.getLsb 0)) by simp [this]
ext i
by_cases h : (↑i : Nat) = 0
· simp [h, Bool.and_comm]
· simp [h]; omega
simp only [shiftLeftRec_zero, twoPow_zero, Nat.reduceAdd, truncate_one,
and_one_eq_zeroExtend_ofBool_getLsb]
case succ n ih =>
simp only [shiftLeftRec_succ, and_twoPow]
rw [ih]
Expand All @@ -431,4 +427,102 @@ theorem shiftLeft_eq_shiftLeftRec (x : BitVec w₁) (y : BitVec w₂) :
· simp [of_length_zero]
· simp [shiftLeftRec_eq]

/- ### Logical shift right (ushiftRight) recurrence for bitblasting -/

def ushiftRight_rec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ :=
let shiftAmt := (y &&& (twoPow w₂ n))
match n with
| 0 => x >>> shiftAmt
| n + 1 => (ushiftRight_rec x y n) >>> shiftAmt

@[simp]
theorem ushiftRight_rec_zero (x : BitVec w₁) (y : BitVec w₂) :
ushiftRight_rec x y 0 = x >>> (y &&& twoPow w₂ 0) := by
simp [ushiftRight_rec]

@[simp]
theorem ushiftRight_rec_succ (x : BitVec w₁) (y : BitVec w₂) :
ushiftRight_rec x y (n + 1) =
(ushiftRight_rec x y n) >>> (y &&& twoPow w₂ (n + 1)) := by
simp [ushiftRight_rec]

theorem ushiftRight'_ushiftRight' {x y z : BitVec w} :
x >>> y >>> z = x >>> (y.toNat + z.toNat) := by
simp [shiftRight_add]

theorem ushiftRight_or_eq_ushiftRight_ushiftRight_of_and_eq_zero {x : BitVec w₁} {y z : BitVec w₂}
(h : y &&& z = 0#w₂) :
x >>> (y ||| z) = x >>> y >>> z := by
simp [← add_eq_or_of_and_eq_zero _ _ h, toNat_add_of_and_eq_zero h, shiftRight_add]

theorem getLsb_ushiftRight' (x : BitVec w₁) (y : BitVec w₂) (i : Nat) :
(x >>> y).getLsb i = x.getLsb (y.toNat + i) := by
simp [getLsb_ushiftRight]

theorem ushiftRight_rec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
ushiftRight_rec x y n = x >>> (y.truncate (n + 1)).zeroExtend w₂ := by
induction n generalizing x y
Comment on lines +463 to +464

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ushiftRight_rec x y n = x >>> (y.truncate (n + 1)).zeroExtend w₂ := by
induction n generalizing x y
ushiftRight_rec x y n = x >>> (y.truncate (n + 1)).zeroExtend w₂ := by
induction n generalizing x y

case zero =>
ext i
simp only [ushiftRight_rec_zero, twoPow_zero, Nat.reduceAdd,
and_one_eq_zeroExtend_ofBool_getLsb, truncate_one]
case succ n ih =>
simp only [ushiftRight_rec_succ, and_twoPow]
rw [ih]
by_cases h : y.getLsb (n + 1) <;> simp only [h, ↓reduceIte]
· rw [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true h]
rw [ushiftRight_or_eq_ushiftRight_ushiftRight_of_and_eq_zero]
simp
· simp [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1), h]

theorem shiftRight_eq_shiftRight_rec (x : BitVec w₁) (y : BitVec w₂) :
x >>> y = ushiftRight_rec x y (w₂ - 1) := by
rcases w₂ with rfl | w₂
· simp [of_length_zero]
· simp [ushiftRight_rec_eq]

/- ### Arithmetic shift right (sshiftRight) recurrence -/

def sshiftRightRec (x : BitVec w) (y : BitVec w₂) (n : Nat) : BitVec w :=
let shiftAmt := (y &&& (twoPow w₂ n))
Comment on lines +486 to +487

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def sshiftRightRec (x : BitVec w) (y : BitVec w₂) (n : Nat) : BitVec w :=
let shiftAmt := (y &&& (twoPow w₂ n))
def sshiftRightRec (x : BitVec w) (y : BitVec w₂) (n : Nat) : BitVec w :=
let shiftAmt := (y &&& (twoPow w₂ n))

match n with
| 0 => x.sshiftRight' shiftAmt
| n + 1 => (sshiftRightRec x y n).sshiftRight' shiftAmt

@[simp]
theorem sshiftRightRec_zero_eq (x : BitVec w) (y : BitVec w₂) :
sshiftRightRec x y 0 = x.sshiftRight' (y &&& 1#w₂) := by
Comment on lines +493 to +494

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
theorem sshiftRightRec_zero_eq (x : BitVec w) (y : BitVec w₂) :
sshiftRightRec x y 0 = x.sshiftRight' (y &&& 1#w₂) := by
theorem sshiftRightRec_zero_eq (x : BitVec w) (y : BitVec w₂) :
sshiftRightRec x y 0 = x.sshiftRight' (y &&& 1#w₂) := by

simp only [sshiftRightRec, twoPow_zero]

@[simp]
theorem sshiftRightRec_succ_eq (x : BitVec w) (y : BitVec w₂) (n : Nat) :
sshiftRightRec x y (n + 1) = (sshiftRightRec x y n).sshiftRight' (y &&& twoPow w₂ (n + 1)) := by
Comment on lines +498 to +499

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
theorem sshiftRightRec_succ_eq (x : BitVec w) (y : BitVec w₂) (n : Nat) :
sshiftRightRec x y (n + 1) = (sshiftRightRec x y n).sshiftRight' (y &&& twoPow w₂ (n + 1)) := by
theorem sshiftRightRec_succ_eq (x : BitVec w) (y : BitVec w₂) (n : Nat) :
sshiftRightRec x y (n + 1) = (sshiftRightRec x y n).sshiftRight' (y &&& twoPow w₂ (n + 1)) := by

simp [sshiftRightRec]

theorem sshiftRight_or_eq_sshiftRight_sshiftRight_of_and_eq_zero {x : BitVec w} {y z : BitVec w₂}
(h : y &&& z = 0#w₂) :
Comment on lines +502 to +503

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
theorem sshiftRight_or_eq_sshiftRight_sshiftRight_of_and_eq_zero {x : BitVec w} {y z : BitVec w₂}
(h : y &&& z = 0#w₂) :
theorem sshiftRight_or_eq_sshiftRight_sshiftRight_of_and_eq_zero {x : BitVec w} {y z : BitVec w₂}
(h : y &&& z = 0#w₂) :

x.sshiftRight' (y ||| z) = (x.sshiftRight' y).sshiftRight' z := by
simp [sshiftRight', ← add_eq_or_of_and_eq_zero _ _ h,
toNat_add_of_and_eq_zero h, sshiftRight'_add]
Comment on lines +505 to +506

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
simp [sshiftRight', ← add_eq_or_of_and_eq_zero _ _ h,
toNat_add_of_and_eq_zero h, sshiftRight'_add]
simp [sshiftRight', ← add_eq_or_of_and_eq_zero _ _ h,
toNat_add_of_and_eq_zero h, sshiftRight'_add]


theorem sshiftRightRec_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
sshiftRightRec x y n = x.sshiftRight' ((y.truncate (n + 1)).zeroExtend w₂) := by
induction n generalizing x y
Comment on lines +509 to +510

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
sshiftRightRec x y n = x.sshiftRight' ((y.truncate (n + 1)).zeroExtend w₂) := by
induction n generalizing x y
sshiftRightRec x y n = x.sshiftRight' ((y.truncate (n + 1)).zeroExtend w₂) := by
induction n generalizing x y

case zero =>
ext i
simp [ushiftRight_rec_zero, twoPow_zero, Nat.reduceAdd, and_one_eq_zeroExtend_ofBool_getLsb,
truncate_one]
case succ n ih =>
simp only [sshiftRightRec_succ_eq, and_twoPow, ih]
by_cases h : y.getLsb (n + 1) <;> simp [h]
· simp [zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true h,
sshiftRight_or_eq_sshiftRight_sshiftRight_of_and_eq_zero]
· simp [zeroExtend_truncate_succ_eq_zeroExtend_truncate_of_getLsb_false (i := n + 1), h]

theorem sshiftRight_eq_sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) :
(x.sshiftRight' y).getLsb i = (sshiftRightRec x y (w₂ - 1)).getLsb i := by
rcases w₂ with rfl | w₂
· simp [of_length_zero]
· simp [sshiftRightRec_eq]

end BitVec
56 changes: 56 additions & 0 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,16 @@ theorem getLsb_shiftLeft' {x : BitVec w₁} {y : BitVec w₂} {i : Nat} :
getLsb (x >>> i) j = getLsb x (i+j) := by
unfold getLsb ; simp

@[simp]
theorem ushiftRight_zero_eq (x : BitVec w) : x >>> 0 = x := by
simp [bv_toNat]

/-! ### ushiftRight reductions from BitVec to Nat -/

@[simp]
theorem ushiftRight_eq' (x : BitVec w) (y : BitVec w₂) :
x >>> y = x >>> y.toNat := by rfl

/-! ### sshiftRight -/

theorem sshiftRight_eq {x : BitVec n} {i : Nat} :
Expand Down Expand Up @@ -795,6 +805,40 @@ theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
Nat.not_lt, decide_eq_true_eq]
omega

/-- The msb after arithmetic shifting right equals the original msb. -/
theorem sshiftRight_msb_eq_msb {n : Nat} {x : BitVec w} :
(x.sshiftRight n).msb = x.msb := by
rw [msb_eq_getLsb_last, getLsb_sshiftRight, msb_eq_getLsb_last]
by_cases hw₀ : w = 0
· simp [hw₀]
· simp only [show ¬(w ≤ w - 1) by omega, decide_False, Bool.not_false, Bool.true_and,
ite_eq_right_iff]
intros h
simp [show n = 0 by omega]

theorem sshiftRight_add {x : BitVec w} {m n : Nat} :
x.sshiftRight (m + n) = (x.sshiftRight m).sshiftRight n := by
ext i
simp only [getLsb_sshiftRight, Nat.add_assoc]
by_cases h₁ : w ≤ (i : Nat)
· simp [h₁]
· simp only [h₁, decide_False, Bool.not_false, Bool.true_and]
by_cases h₂ : n + ↑i < w
· simp [h₂]
· simp only [h₂, ↓reduceIte]
by_cases h₃ : m + (n + ↑i) < w
· simp [h₃]
omega
· simp [h₃, sshiftRight_msb_eq_msb]

/-! ### shiftRight reductions from BitVec to Nat -/

@[simp]
theorem sshiftRight'_zero (x : BitVec w) :
x.sshiftRight' (0#w₂) = x := by
ext i
simp [sshiftRight', getLsb_sshiftRight]

/-! ### signExtend -/

/-- Equation theorem for `Int.sub` when both arguments are `Int.ofNat` -/
Expand Down Expand Up @@ -929,6 +973,10 @@ theorem shiftRight_add {w : Nat} (x : BitVec w) (n m : Nat) :
ext i
simp [Nat.add_assoc n m i]

theorem sshiftRight'_add {x : BitVec w₁} {y : BitVec w₂} {z : BitVec w₃} :
x.sshiftRight (y.toNat + z.toNat) = (x.sshiftRight' y).sshiftRight' z := by
simp [sshiftRight', shiftRight_add, sshiftRight_add]

@[deprecated shiftRight_add (since := "2024-06-02")]
theorem shiftRight_shiftRight {w : Nat} (x : BitVec w) (n m : Nat) :
(x >>> n) >>> m = x >>> (n + m) := by
Expand Down Expand Up @@ -1549,4 +1597,12 @@ theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_or_twoPow_of_getLsb_true
simp [hx]
· by_cases hik' : k < i + 1 <;> simp [hik, hik'] <;> omega

/-- Bitwise `and` of `(x : BitVec w`) with `1#w` equals zero extending the `lsb` to `w`. -/
theorem and_one_eq_zeroExtend_ofBool_getLsb {x : BitVec w} :
(x &&& 1#w) = zeroExtend w (ofBool (x.getLsb 0)) := by
ext i
simp only [getLsb_and, getLsb_one, getLsb_zeroExtend, Fin.is_lt, decide_True, getLsb_ofBool,
Bool.true_and]
by_cases h : (0 = (i : Nat)) <;> simp [h] <;> omega

end BitVec
Loading