Skip to content

Commit

Permalink
add setWidth Proofs
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasgrosser committed Dec 8, 2024
1 parent 39df0cb commit dcd6e05
Showing 1 changed file with 55 additions and 21 deletions.
76 changes: 55 additions & 21 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,10 @@ theorem zeroExtend_eq_setWidth {v : Nat} {x : BitVec w} :
(x.setWidth v).toInt = Int.bmod x.toNat (2^v) := by
simp [toInt_eq_toNat_bmod, toNat_setWidth, Int.emod_bmod]

@[simp] theorem toFin_setWidth (x : BitVec w) :
(x.setWidth v).toFin = Fin.ofNat' (2^v) x.toNat := by
ext; simp

theorem setWidth'_eq {x : BitVec w} (h : w ≤ v) : x.setWidth' h = x.setWidth v := by
apply eq_of_toNat_eq
rw [toNat_setWidth, toNat_setWidth']
Expand Down Expand Up @@ -651,6 +655,21 @@ theorem getElem?_setWidth (m : Nat) (x : BitVec n) (i : Nat) :
getLsbD (setWidth m x) i = (decide (i < m) && getLsbD x i) := by
simp [getLsbD, toNat_setWidth, Nat.testBit_mod_two_pow]

@[simp] theorem getMsbD_setWidth {m : Nat} {x : BitVec n} {i : Nat} :
getMsbD (setWidth m x) i = (decide (m - n ≤ i) && getMsbD x (i + n - m)) := by
unfold setWidth
by_cases h : n ≤ m <;> simp only [h]
· by_cases h' : (m - n ≤ i)
· simp [h', show (i - (m - n)) = i + n - m by omega]
· simp [h']
· simp only [show m-n = 0 by omega, getMsbD, getLsbD_setWidth]
by_cases h'' : i < m
· simp [show m - 1 - i < m by omega, show i + n - m < n by omega,
show (n - 1 - (i + n - m)) = m - 1 - i by omega]
omega
· simp [h'']
omega

@[simp] theorem getMsbD_setWidth_add {x : BitVec w} (h : k ≤ i) :
(x.setWidth (w + k)).getMsbD i = x.getMsbD (i - k) := by
by_cases h : w = 0
Expand Down Expand Up @@ -1627,7 +1646,7 @@ private theorem Int.negSucc_emod (m : Nat) (n : Int) :
-(m + 1) % n = Int.subNatNat (Int.natAbs n) ((m % Int.natAbs n) + 1) := rfl

/-- The sign extension is the same as zero extending when `msb = false`. -/
theorem signExtend_eq_not_setWidth_not_of_msb_false {x : BitVec w} {v : Nat} (hmsb : x.msb = false) :
theorem signExtend_eq_setWidth_of_msb_false {x : BitVec w} {v : Nat} (hmsb : x.msb = false) :
x.signExtend v = x.setWidth v := by
ext i
by_cases hv : i < v
Expand Down Expand Up @@ -1663,39 +1682,54 @@ theorem signExtend_eq_not_setWidth_not_of_msb_true {x : BitVec w} {v : Nat} (hms
theorem getLsbD_signExtend (x : BitVec w) {v i : Nat} :
(x.signExtend v).getLsbD i = (decide (i < v) && if i < w then x.getLsbD i else x.msb) := by
rcases hmsb : x.msb with rfl | rfl
· rw [signExtend_eq_not_setWidth_not_of_msb_false hmsb]
· rw [signExtend_eq_setWidth_of_msb_false hmsb]
by_cases (i < v) <;> by_cases (i < w) <;> simp_all <;> omega
· rw [signExtend_eq_not_setWidth_not_of_msb_true hmsb]
by_cases (i < v) <;> by_cases (i < w) <;> simp_all <;> omega

protected theorem Nat.sub_sub_comm {n m k : Nat} : n - m - k = n - k - m := sorry

theorem getMsbD_signExtend (x : BitVec w) {v i : Nat} :
(x.signExtend v).getMsbD i = (decide (i < v) && if w+i-v < w then x.getMsbD (w+i-v) else x.msb) := by
(x.signExtend v).getMsbD i = (decide (i < v) && (if i > v-w then x.getMsbD (i-(v-w)) else x.msb)) := by
rcases hmsb : x.msb with rfl | rfl
· rw [signExtend_eq_not_setWidth_not_of_msb_false hmsb]
simp_all [getMsbD]
by_cases h' : (i < v) <;> by_cases h'': ((w+i-v < w)) <;> simp [getMsbD, h', h'']
have h''': ((v - 1 - i < v)) := by omega
simp [h''']
by_cases h5 : v ≤ w
· rw [show (w - 1 - (w + i - v)) = (v - 1 - i) by (
rw [Nat.sub_add_comm]
rw [← Nat.sub_sub]
rw [Nat.sub_sub_comm (m := 1)]
rw [Nat.sub_sub_eq_min]
rw [Nat.min_eq_right]
omega; omega)]
·
· simp [signExtend_eq_setWidth_of_msb_false hmsb]
simp [getMsbD_setWidth']


rw [signExtend_eq_not_setWidth_not_of_msb_false hmsb]
rw [getMsbD]
rw [getMsbD]
simp
by_cases h : i < v
·
simp [h]
simp [show v - 1 - i < v by omega]
by_cases h' : v - w < i
· simp [h']
simp [show (i - (v - w) < w) by omega]
rw [show (w - 1 - (i - (v-w))) = (v - 1 - i) by omega]
· simp [h']
omega
· simp [h]

· rw [signExtend_eq_not_setWidth_not_of_msb_true hmsb]
by_cases h' : (i < v) <;> by_cases h'': ((w+i-v < w)) <;> simp [getMsbD, h', h'']
have h''': ((v - 1 - i < w)) := by omega
simp[h''']
rw [getMsbD]
rw [getMsbD]
simp
by_cases h : i < v
·
simp [h]
simp [show v - 1 - i < v by omega]
by_cases h' : v - w < i
· simp [h']
simp [show (i - (v - w) < w) by omega]
simp [show (v - 1 - i < w) by omega]
rw [show (w - 1 - (i - (v-w))) = (v - 1 - i) by omega]
· simp [h']
rw [getlsbD_of]

<;> rw [show (w - 1 - (w - v + i)) = (v - 1 - i) by omega]
omega
· simp [h]

theorem getElem_signExtend {x : BitVec w} {v i : Nat} (h : i < v) :
(x.signExtend v)[i] = if i < w then x.getLsbD i else x.msb := by
Expand Down

0 comments on commit dcd6e05

Please sign in to comment.