Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add BitVec.toFin_append #36

Draft
wants to merge 15 commits into
base: master
Choose a base branch
from
236 changes: 234 additions & 2 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,9 @@ theorem toInt_ofNat {n : Nat} (x : Nat) :
@[simp] theorem ofInt_ofNat (w n : Nat) :
BitVec.ofInt w (no_index (OfNat.ofNat n)) = BitVec.ofNat w (OfNat.ofNat n) := rfl

@[simp] theorem ofInt_toInt (x : BitVec w) : BitVec.ofInt w (x.toInt) = x := by
by_cases h : 2 * x.toNat < 2^w <;> ext <;> simp [getLsbD, h, BitVec.toInt]

theorem toInt_neg_iff {w : Nat} {x : BitVec w} :
BitVec.toInt x < 0 ↔ 2 ^ w ≤ 2 * x.toNat := by
simp [toInt_eq_toNat_cond]; omega
Expand All @@ -520,6 +523,9 @@ theorem eq_zero_or_eq_one (a : BitVec 1) : a = 0#1 ∨ a = 1#1 := by
theorem toInt_zero {w : Nat} : (0#w).toInt = 0 := by
simp [BitVec.toInt, show 0 < 2^w by exact Nat.two_pow_pos w]

@[simp] theorem toInt_cast (h : w = v) (x : BitVec w) : (cast h x).toInt = x.toInt := by
simp [toInt_eq_toNat_bmod, h]

/-! ### slt -/

/--
Expand Down Expand Up @@ -569,6 +575,10 @@ theorem zeroExtend_eq_setWidth {v : Nat} {x : BitVec w} :
(x.setWidth v).toInt = Int.bmod x.toNat (2^v) := by
simp [toInt_eq_toNat_bmod, toNat_setWidth, Int.emod_bmod]

@[simp] theorem toFin_setWidth (x : BitVec w) :
(x.setWidth v).toFin = Fin.ofNat' (2^v) x.toNat := by
ext; simp

theorem setWidth'_eq {x : BitVec w} (h : w ≤ v) : x.setWidth' h = x.setWidth v := by
apply eq_of_toNat_eq
rw [toNat_setWidth, toNat_setWidth']
Expand Down Expand Up @@ -645,6 +655,21 @@ theorem getElem?_setWidth (m : Nat) (x : BitVec n) (i : Nat) :
getLsbD (setWidth m x) i = (decide (i < m) && getLsbD x i) := by
simp [getLsbD, toNat_setWidth, Nat.testBit_mod_two_pow]

theorem getMsbD_setWidth {m : Nat} {x : BitVec n} {i : Nat} :
getMsbD (setWidth m x) i = (decide (m - n ≤ i) && getMsbD x (i + n - m)) := by
unfold setWidth
by_cases h : n ≤ m <;> simp only [h]
· by_cases h' : (m - n ≤ i)
· simp [h', show (i - (m - n)) = i + n - m by omega]
· simp [h']
· simp only [show m-n = 0 by omega, getMsbD, getLsbD_setWidth]
by_cases h'' : i < m
· simp [show m - 1 - i < m by omega, show i + n - m < n by omega,
show (n - 1 - (i + n - m)) = m - 1 - i by omega]
omega
· simp [h'']
omega

@[simp] theorem getMsbD_setWidth_add {x : BitVec w} (h : k ≤ i) :
(x.setWidth (w + k)).getMsbD i = x.getMsbD (i - k) := by
by_cases h : w = 0
Expand Down Expand Up @@ -1124,6 +1149,11 @@ theorem not_eq_comm {x y : BitVec w} : ~~~ x = y ↔ x = ~~~ y := by
BitVec.toNat (x <<< n) = BitVec.toNat x <<< n % 2^v :=
BitVec.toNat_ofNat _ _

@[simp] theorem toInt_shiftLeft {n : Nat} {x : BitVec v} :
BitVec.toInt (x <<< n) = Int.bmod (BitVec.toInt x * 2^n) (2^v) := by
simp [toInt_eq_toNat_bmod, Nat.shiftLeft_eq]
norm_cast

@[simp] theorem toFin_shiftLeft {n : Nat} (x : BitVec w) :
BitVec.toFin (x <<< n) = Fin.ofNat' (2^w) (x.toNat <<< n) := rfl

Expand Down Expand Up @@ -1223,6 +1253,7 @@ theorem shiftLeftZeroExtend_eq {x : BitVec w} :
(shiftLeftZeroExtend x i).msb = x.msb := by
simp [shiftLeftZeroExtend_eq, BitVec.msb]


theorem shiftLeft_add {w : Nat} (x : BitVec w) (n m : Nat) :
x <<< (n + m) = (x <<< n) <<< m := by
ext i
Expand Down Expand Up @@ -1616,7 +1647,7 @@ private theorem Int.negSucc_emod (m : Nat) (n : Int) :
-(m + 1) % n = Int.subNatNat (Int.natAbs n) ((m % Int.natAbs n) + 1) := rfl

/-- The sign extension is the same as zero extending when `msb = false`. -/
theorem signExtend_eq_not_setWidth_not_of_msb_false {x : BitVec w} {v : Nat} (hmsb : x.msb = false) :
theorem signExtend_eq_setWidth_of_msb_false {x : BitVec w} {v : Nat} (hmsb : x.msb = false) :
x.signExtend v = x.setWidth v := by
ext i
by_cases hv : i < v
Expand Down Expand Up @@ -1652,16 +1683,33 @@ theorem signExtend_eq_not_setWidth_not_of_msb_true {x : BitVec w} {v : Nat} (hms
theorem getLsbD_signExtend (x : BitVec w) {v i : Nat} :
(x.signExtend v).getLsbD i = (decide (i < v) && if i < w then x.getLsbD i else x.msb) := by
rcases hmsb : x.msb with rfl | rfl
· rw [signExtend_eq_not_setWidth_not_of_msb_false hmsb]
· rw [signExtend_eq_setWidth_of_msb_false hmsb]
by_cases (i < v) <;> by_cases (i < w) <;> simp_all <;> omega
· rw [signExtend_eq_not_setWidth_not_of_msb_true hmsb]
by_cases (i < v) <;> by_cases (i < w) <;> simp_all <;> omega

theorem getMsbD_signExtend {x : BitVec w} {v i : Nat} :
(x.signExtend v).getMsbD i =
(decide (i < v) && if v - w ≤ i then x.getMsbD (i + w - v) else x.msb) := by
rcases hmsb : x.msb with rfl | rfl
· simp [signExtend_eq_setWidth_of_msb_false hmsb, getMsbD_setWidth]
by_cases h : (v - w ≤ i) <;> simp [h, getMsbD] <;> omega
· simp only [signExtend_eq_not_setWidth_not_of_msb_true hmsb, getMsbD_not,
getMsbD_setWidth, Bool.not_and, Bool.not_not, Bool.if_true_right]
by_cases h : i < v <;> by_cases h' : v - w ≤ i <;> simp [h, h'] <;> omega

theorem getElem_signExtend {x : BitVec w} {v i : Nat} (h : i < v) :
(x.signExtend v)[i] = if i < w then x.getLsbD i else x.msb := by
rw [←getLsbD_eq_getElem, getLsbD_signExtend]
simp [h]

theorem msb_SignExtend {x : BitVec w} :
(x.signExtend v).msb = (decide (0 < v) && if w ≥ v then x.getMsbD (w - v) else x.msb) := by
simp [BitVec.msb, getMsbD_signExtend]
by_cases h : w ≥ v
· simp [h, show v - w = 0 by omega]
· simp [h, show ¬ (v - w = 0) by omega]

/-- Sign extending to a width smaller than the starting width is a truncation. -/
theorem signExtend_eq_setWidth_of_lt (x : BitVec w) {v : Nat} (hv : v ≤ w):
x.signExtend v = x.setWidth v := by
Expand Down Expand Up @@ -1759,6 +1807,23 @@ theorem append_def (x : BitVec v) (y : BitVec w) :
(x ++ y).toNat = x.toNat <<< n ||| y.toNat :=
rfl

/-- Helper theorem to show that the expression in `(x ++ y).toFin` is inbounds. -/
theorem toNat_append_lt {m n : Nat} (x : BitVec m) (y : BitVec n) :
x.toNat <<< n ||| y.toNat < 2 ^ (m + n) := by
have hnLe : 2^n ≤ 2 ^(m + n) := by
rw [Nat.pow_add]
exact Nat.le_mul_of_pos_left (2 ^ n) (Nat.two_pow_pos m)
apply Nat.or_lt_two_pow
· have := Nat.two_pow_pos n
rw [Nat.shiftLeft_eq, Nat.pow_add, Nat.mul_lt_mul_right]
<;> omega
· omega

@[simp] theorem toFin_append {x : BitVec m} {y : BitVec n} :
(x ++ y).toFin = @Fin.mk (2^(m+n)) (x.toNat <<< n ||| y.toNat) (toNat_append_lt x y) := by
ext
Comment on lines +1822 to +1824
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's more general to have h be a proof, I have had situations where lean fails to unify correctly with the proposition when we fix the proof to be a a particular proof, even if we have proof irrelevance.

simp

theorem getLsbD_append {x : BitVec n} {y : BitVec m} :
getLsbD (x ++ y) i = bif i < m then getLsbD y i else getLsbD x (i - m) := by
simp only [append_def, getLsbD_or, getLsbD_shiftLeftZeroExtend, getLsbD_setWidth']
Expand Down Expand Up @@ -1806,6 +1871,173 @@ theorem msb_append {x : BitVec w} {y : BitVec v} :
ext
simp only [getLsbD_append, getLsbD_zero, Bool.cond_self]

theorem append_zero {n m : Nat} {x : BitVec n} :
x ++ 0#m = x.signExtend (n + m) <<< m := by
induction m
case zero =>
simp [signExtend]
case succ i ih =>
simp [bv_toNat]
sorry

def lhs (x : BitVec n) (y : BitVec m) : Int := (x++y).toInt
def rhs (x : BitVec n) (y : BitVec m) : Int := if n == 0 then y.toInt else (x.toInt * (2^m)) + y.toNat

def eq (x: BitVec n) (y: BitVec m) : Bool := (lhs x y) = (rhs x y)


#eval (-5#10 ++ 3#2).toInt

def test : Bool := Id.run do
for i in [0, 1, 2, 3, 4, 5, 6, 7, 8] do
for j in [0, 1, 2, 3, 4, 5, 6, 7, 8] do
for n in [0, 1, 2, 3, 4] do
for m in [0, 1, 2, 3, 4] do
let x := BitVec.ofNat n i
let y := BitVec.ofNat m j
if (!eq x y) then
return false
return true

private theorem Nat.lt_mul_of_le_of_lt_of_lt {a b c : Nat} (hab : a ≤ b) (ha : 0 < a) (hc : 1 < c) :
a < b * c := by
have : a * 1 < b * c := Nat.mul_lt_mul_of_le_of_lt' (by omega) (by simp [hc]) (by omega)
simp at this
simp [this]

private theorem Nat.two_pow_lt_two_pow_add {n m : Nat} (h : m ≠ 0) :
2 ^ n < 2 ^ (n + m) := by
apply Nat.pow_lt_pow_of_lt (by omega) (by omega)

@[simp] theorem signExtend_shiftLeft_msb {n m : Nat} {x : BitVec n} :
(signExtend (n + m) x <<< m).msb = x.msb := by
induction m
case zero =>
simp [signExtend]
case succ i ih =>
rw [← ih]
rw [msb_setWidth]

unfold BitVec.msb getMsbD
simp
by_cases h : (0 < n + i)
·
rw [← Nat.add_assoc]
simp [h]
have h' : (0 < n + i + 1) := by omega
have hh : (n + i - (i + 1)) = (n + i - i - 1) := by
omega
rw [hh]
simp
have hhh : (n + i - 1 - i) = (n + i - i - 1) := by omega
rw [hhh]
simp
rw [getLsbD_signExtend]






simp [BitVec.msb, getMsbD]

by_cases h : 0 < n + (i + 1)
· simp [h]

sorry
· simp [h]
sorry

@[simp] theorem signExtend_toNat_shift_mod :
((signExtend (n + m) x).toNat <<< m) % ↑(2 ^ (n + m)) = (signExtend (n + m) x).toNat <<< m :=
sorry

@[simp] theorem toInt_append_zero {n m : Nat} {x : BitVec n} :
(x ++ 0#m).toInt = x.toInt * (2 ^ m) := by
by_cases m0 : m = 0
· subst m0
simp
· simp only [ofNat_eq_ofNat, append_zero, toInt_eq_msb_cond]
by_cases h1 : (signExtend (n + m) x <<< m).msb
· by_cases h2: x.msb
· norm_cast
simp [h1, h2]
norm_cast
rw [Int.sub_mul, Nat.pow_add]
norm_cast
simp
rw [Nat.shiftLeft_eq]
norm_cast
have aa := @Nat.pow_pos 2 m (by omega)
norm_cast
have bb := @Nat.mul_right_cancel_iff (2^m) ((signExtend (n + m) x).toNat)
apply bb
rfl
rw [Nat.mul_right_cancel (m := 2 ^ m)]
simp [aa]
rw [Nat.mod_eq_of_lt (a := x.toNat) (by omega)]
norm_cast
simp [h3]
simp_all
rw [Nat.shiftLeft_eq]
· simp only [signExtend_shiftLeft_of_lt] at h1
contradiction
· by_cases h2: x.msb
· simp [signExtend_shiftLeft_of_lt, h2] at h1
· sorry

@[simp] theorem toInt_append {x : BitVec n} {y : BitVec m} :
(x ++ y).toInt = if n == 0 then y.toInt else x.toInt * (2 ^ m) + y.toNat := by
by_cases n0 : n = 0
· subst n0
simp [BitVec.eq_nil x]
· by_cases m0 : m = 0
· subst m0
simp [BitVec.eq_nil y, n0]
· simp [m0]
by_cases y0 : y = 0
· simp [toInt_append_zero, y0, n0]
rw [toInt_eq_toNat_cond]
rw [toInt_eq_toNat_cond]
split
·
split
<;> norm_cast
<;> simp
<;> rw [Nat.mod_eq_of_lt (a := x.toNat) (by omega)]
<;> norm_cast
<;> simp [h3]
<;> simp_all
· rw [Nat.shiftLeft_eq]
· rename_i aa bb
rw [Nat.shiftLeft_eq] at aa
rw [Nat.pow_add] at aa
rw [← Nat.mul_assoc] at aa

sorry
·
split
<;> norm_cast
<;> simp
<;> rw [Nat.mod_eq_of_lt (a := x.toNat) (by omega)]
<;> norm_cast
<;> simp [h3]
<;> simp_all
· rename_i aa bb
rw [Nat.shiftLeft_eq] at aa
rw [Nat.pow_add] at aa
rw [← Nat.mul_assoc] at aa




simp_all


sorry
· simp [Nat.shiftLeft_eq, Int.sub_mul, Nat.pow_add]
· sorry

@[simp] theorem cast_append_right (h : w + v = w + v') (x : BitVec w) (y : BitVec v) :
cast h (x ++ y) = x ++ cast (by omega) y := by
ext
Expand Down
Loading