Skip to content

Commit

Permalink
feat: getLsb_sshiftRight (leanprover#4179)
Browse files Browse the repository at this point in the history
In the course of the development, I grabbed facts about right shifting
over integers [from
`mathlib4`](https://github.com/leanprover-community/mathlib4/blob/master/Mathlib/Data/Int/Bitwise.lean).

The core proof strategy is to perform a case analysis of the msb:
- If `msb = false`, then `sshiftRight = ushiftRight`.
- If `msb = true`. then `x >>>s i = ~~~(~~~(x >>>u i))`. The double
negation introduces the high `1` bits that one expects of the arithmetic
shift.

---------

Co-authored-by: Kim Morrison <[email protected]>
  • Loading branch information
bollu and kim-em authored Jun 1, 2024
1 parent 9a597ae commit 81f5b07
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 0 deletions.
68 changes: 68 additions & 0 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Init.Data.BitVec.Basic
import Init.Data.Fin.Lemmas
import Init.Data.Nat.Lemmas
import Init.Data.Nat.Mod
import Init.Data.Int.Bitwise.Lemmas

namespace BitVec

Expand Down Expand Up @@ -281,6 +282,9 @@ theorem toInt_ofNat {n : Nat} (x : Nat) :
have p : 0 ≤ i % (2^n : Nat) := by omega
simp [toInt_eq_toNat_bmod, Int.toNat_of_nonneg p]

@[simp] theorem ofInt_natCast (w n : Nat) :
BitVec.ofInt w (n : Int) = BitVec.ofNat w n := rfl

/-! ### zeroExtend and truncate -/

@[simp, bv_toNat] theorem toNat_zeroExtend' {m n : Nat} (p : m ≤ n) (x : BitVec m) :
Expand Down Expand Up @@ -658,6 +662,70 @@ theorem shiftLeft_shiftLeft {w : Nat} (x : BitVec w) (n m : Nat) :
getLsb (x >>> i) j = getLsb x (i+j) := by
unfold getLsb ; simp

/-! ### sshiftRight -/

theorem sshiftRight_eq {x : BitVec n} {i : Nat} :
x.sshiftRight i = BitVec.ofInt n (x.toInt >>> i) := by
apply BitVec.eq_of_toInt_eq
simp [BitVec.sshiftRight]

/-- if the msb is false, the arithmetic shift right equals logical shift right -/
theorem sshiftRight_eq_of_msb_false {x : BitVec w} {s : Nat} (h : x.msb = false) :
(x.sshiftRight s) = x >>> s := by
apply BitVec.eq_of_toNat_eq
rw [BitVec.sshiftRight_eq, BitVec.toInt_eq_toNat_cond]
have hxbound : 2 * x.toNat < 2 ^ w := (BitVec.msb_eq_false_iff_two_mul_lt x).mp h
simp only [hxbound, ↓reduceIte, Int.natCast_shiftRight, Int.ofNat_eq_coe, ofInt_natCast,
toNat_ofNat, toNat_ushiftRight]
replace hxbound : x.toNat >>> s < 2 ^ w := by
rw [Nat.shiftRight_eq_div_pow]
exact Nat.lt_of_le_of_lt (Nat.div_le_self ..) x.isLt
apply Nat.mod_eq_of_lt hxbound

/--
If the msb is `true`, the arithmetic shift right equals negating,
then logical shifting right, then negating again.
The double negation preserves the lower bits that have been shifted,
and the outer negation ensures that the high bits are '1'. -/
theorem sshiftRight_eq_of_msb_true {x : BitVec w} {s : Nat} (h : x.msb = true) :
(x.sshiftRight s) = ~~~((~~~x) >>> s) := by
apply BitVec.eq_of_toNat_eq
rcases w with rfl | w
· simp
· rw [BitVec.sshiftRight_eq, BitVec.toInt_eq_toNat_cond]
have hxbound : (2 * x.toNat ≥ 2 ^ (w + 1)) := (BitVec.msb_eq_true_iff_two_mul_ge x).mp h
replace hxbound : ¬ (2 * x.toNat < 2 ^ (w + 1)) := by omega
simp only [hxbound, ↓reduceIte, toNat_ofInt, toNat_not, toNat_ushiftRight]
rw [← Int.subNatNat_eq_coe, Int.subNatNat_of_lt (by omega),
Nat.pred_eq_sub_one, Int.negSucc_shiftRight,
Int.emod_negSucc, Int.natAbs_ofNat, Nat.succ_eq_add_one,
Int.subNatNat_of_le (by omega), Int.toNat_ofNat, Nat.mod_eq_of_lt,
Nat.sub_right_comm]
omega
· rw [Nat.shiftRight_eq_div_pow]
apply Nat.lt_of_le_of_lt (Nat.div_le_self _ _) (by omega)

theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
getLsb (x.sshiftRight s) i =
(!decide (w ≤ i) && if s + i < w then x.getLsb (s + i) else x.msb) := by
rcases hmsb : x.msb with rfl | rfl
· simp only [sshiftRight_eq_of_msb_false hmsb, getLsb_ushiftRight, Bool.if_false_right]
by_cases hi : i ≥ w
· simp only [hi, decide_True, Bool.not_true, Bool.false_and]
apply getLsb_ge
omega
· simp only [hi, decide_False, Bool.not_false, Bool.true_and, Bool.iff_and_self,
decide_eq_true_eq]
intros hlsb
apply BitVec.lt_of_getLsb _ _ hlsb
· by_cases hi : i ≥ w
· simp [hi]
· simp only [sshiftRight_eq_of_msb_true hmsb, getLsb_not, getLsb_ushiftRight, Bool.not_and,
Bool.not_not, hi, decide_False, Bool.not_false, Bool.if_true_right, Bool.true_and,
Bool.and_iff_right_iff_imp, Bool.or_eq_true, Bool.not_eq_true', decide_eq_false_iff_not,
Nat.not_lt, decide_eq_true_eq]
omega

/-! ### append -/

theorem append_def (x : BitVec v) (y : BitVec w) :
Expand Down
37 changes: 37 additions & 0 deletions src/Init/Data/Int/Bitwise/Lemmas.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/-
Copyright (c) 2023 Siddharth Bhat. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Siddharth Bhat, Jeremy Avigad
-/
prelude
import Init.Data.Nat.Bitwise.Lemmas
import Init.Data.Int.Bitwise

namespace Int

theorem shiftRight_eq (n : Int) (s : Nat) : n >>> s = Int.shiftRight n s := rfl
@[simp]
theorem natCast_shiftRight (n s : Nat) : (n : Int) >>> s = n >>> s := rfl

@[simp]
theorem negSucc_shiftRight (m n : Nat) :
-[m+1] >>> n = -[m >>>n +1] := rfl

theorem shiftRight_add (i : Int) (m n : Nat) :
i >>> (m + n) = i >>> m >>> n := by
simp only [shiftRight_eq, Int.shiftRight]
cases i <;> simp [Nat.shiftRight_add]

theorem shiftRight_eq_div_pow (m : Int) (n : Nat) :
m >>> n = m / ((2 ^ n) : Nat) := by
simp only [shiftRight_eq, Int.shiftRight, Nat.shiftRight_eq_div_pow]
split
· simp
· rw [negSucc_ediv _ (by norm_cast; exact Nat.pow_pos (Nat.zero_lt_two))]
rfl

@[simp]
theorem zero_shiftRight (n : Nat) : (0 : Int) >>> n = 0 := by
simp [Int.shiftRight_eq_div_pow]

end Int
3 changes: 3 additions & 0 deletions src/Init/Data/Int/DivModLemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,9 @@ theorem negSucc_emod (m : Nat) {b : Int} (bpos : 0 < b) : -[m+1] % b = b - 1 - m
match b, eq_succ_of_zero_lt bpos with
| _, ⟨n, rfl⟩ => rfl

theorem emod_negSucc (m : Nat) (n : Int) :
(Int.negSucc m) % n = Int.subNatNat (Int.natAbs n) (Nat.succ (m % Int.natAbs n)) := rfl

theorem ofNat_mod_ofNat (m n : Nat) : (m % n : Int) = ↑(m % n) := rfl

theorem emod_nonneg : ∀ (a : Int) {b : Int}, b ≠ 00 ≤ a % b
Expand Down

0 comments on commit 81f5b07

Please sign in to comment.