From 9f51eb4db371294cfd572c145136fa8a5ba15c65 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Mon, 14 Oct 2024 21:33:46 -0500 Subject: [PATCH] feat: fix BitVec.abs, prove toInt produces the expected value. The previous definition of `abs` was incorrect when only the msb was `1` and all other bits were `0`. For example, consider bit-width 3: ``` 100 -- 4#3 ``` If we compute `-x`, i.e. `!x + 1`, we get: ``` 011 +001 --- 100 ``` We recover `4#3` once again. The semantically correct implementation can use `BitVec.slt`, and we can prove the equivalence to the bit-fiddling hack: ``` // https://math.stackexchange.com/q/2565736/261373 int iabs(int a) { int t = a >> 31; a = (a^t) - t; return a; } ``` written in lean, this is: ``` def BitVec.abs' (x : BitVec w) : let t := x >> (w - 1) (x ^^^ t) - t ``` --- src/Init/Data/BitVec/Basic.lean | 11 +++--- src/Init/Data/BitVec/Lemmas.lean | 60 +++++++++++++++++++++++++++----- src/Init/Data/Int/Basic.lean | 7 ++++ src/Init/Data/Int/Lemmas.lean | 24 +++++++++++++ 4 files changed, 89 insertions(+), 13 deletions(-) diff --git a/src/Init/Data/BitVec/Basic.lean b/src/Init/Data/BitVec/Basic.lean index 432db5296a1f..df45a9d9ef22 100644 --- a/src/Init/Data/BitVec/Basic.lean +++ b/src/Init/Data/BitVec/Basic.lean @@ -263,11 +263,6 @@ SMT-Lib name: `bvneg`. protected def neg (x : BitVec n) : BitVec n := .ofNat n (2^n - x.toNat) instance : Neg (BitVec n) := ⟨.neg⟩ -/-- -Return the absolute value of a signed bitvector. --/ -protected def abs (x : BitVec n) : BitVec n := if x.msb then .neg x else x - /-- Multiplication for bit vectors. This can be interpreted as either signed or unsigned multiplication modulo `2^n`. @@ -422,6 +417,12 @@ protected def sle (x y : BitVec n) : Bool := x.toInt ≤ y.toInt end relations +/-- +Return the absolute value of a signed bitvector. +-/ +protected def abs (x : BitVec n) : BitVec n := if x.toInt < 0 then .neg x else x + + section cast /-- `cast eq x` embeds `x` into an equal `BitVec` type. -/ diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index ba6260e0d488..6951bf1f9a2d 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -11,6 +11,7 @@ import Init.Data.Fin.Lemmas import Init.Data.Nat.Lemmas import Init.Data.Nat.Mod import Init.Data.Int.Bitwise.Lemmas +import Init.Data.Int.Lemmas import Init.Data.Int.Pow set_option linter.missingDocs true @@ -206,6 +207,7 @@ theorem eq_of_getMsbD_eq {x y : BitVec w} theorem of_length_zero {x : BitVec 0} : x = 0#0 := by ext; simp theorem toNat_zero_length (x : BitVec 0) : x.toNat = 0 := by simp [of_length_zero] +theorem toInt_length_zero (x : BitVec 0) : x.toInt = 0 := by simp [of_length_zero] theorem getLsbD_zero_length (x : BitVec 0) : x.getLsbD i = false := by simp theorem getMsbD_zero_length (x : BitVec 0) : x.getMsbD i = false := by simp theorem msb_zero_length (x : BitVec 0) : x.msb = false := by simp [BitVec.msb, of_length_zero] @@ -2070,16 +2072,58 @@ theorem smod_zero {x : BitVec n} : x.smod 0#n = x := by · simp · by_cases h : x = 0#n <;> simp [h] +/-! ### slt -/ + +def BitVec.slt_def {x y : BitVec n} : (x.slt y) = (x.toInt < y.toInt) := by + simp + + /-! ### abs -/ -@[simp, bv_toNat] -theorem toNat_abs {x : BitVec w} : x.abs.toNat = if x.msb then 2^w - x.toNat else x.toNat := by - simp only [BitVec.abs, neg_eq] - by_cases h : x.msb = true - · simp only [h, ↓reduceIte, toNat_neg] - have : 2 * x.toNat ≥ 2 ^ w := BitVec.msb_eq_true_iff_two_mul_ge.mp h - rw [Nat.mod_eq_of_lt (by omega)] - · simp [h] +private theorem two_pow_plus_one_div_two (w : Nat) : ((2^w + 1) / 2) = 2^(w - 1) := by + apply Nat.div_eq_of_lt_le + · rcases w with rfl | w + · decide + · simp + omega + · rcases w with rfl | w + · decide + · rw [Nat.add_one_sub_one, Nat.add_mul, + Nat.one_mul, Nat.add_lt_add_iff_right, + Nat.pow_add] + omega + +theorem abs_def {x : BitVec w} : x.abs = if x.toInt < 0 then -x else x := rfl + +/-- The value of the bitvector (interpreted as an integer) is always less than 2^w -/ +theorem toInt_lt (x : BitVec w) : x.toInt < 2 ^ w := by + rw [toInt_eq_msb_cond] + norm_cast + omega + +/-- The negation value of the bitvector (interpreted as an integer) is always less than 2^w -/ +theorem neg_toInt_lt (x : BitVec w) : - x.toInt < 2 ^ w := by + have := toInt_lt x + rw [toInt_eq_msb_cond] + split + case isTrue h => + simp only [gt_iff_lt] + norm_cast + have := msb_eq_true_iff_two_mul_ge.mp h + omega + case isFalse h => + norm_cast + omega + +theorem toInt_abs (x : BitVec w) : x.abs.toInt = x.toInt.abs := by + · rw [abs_def] + split + case isTrue h => + rw [Int.Int.abs_eq_neg (by omega)] + sorry + case isFalse h => + rw [Int.abs_eq_self (by omega)] + /-! ### mul -/ diff --git a/src/Init/Data/Int/Basic.lean b/src/Init/Data/Int/Basic.lean index dbf661c4be1b..fed9719af55b 100644 --- a/src/Init/Data/Int/Basic.lean +++ b/src/Init/Data/Int/Basic.lean @@ -333,6 +333,13 @@ instance : Min Int := minOfLe instance : Max Int := maxOfLe +/-- +Return the absolute value of an integer. +-/ +def abs : Int → Int + | ofNat n => .ofNat n + | negSucc n => .ofNat n.succ + end Int /-- diff --git a/src/Init/Data/Int/Lemmas.lean b/src/Init/Data/Int/Lemmas.lean index 4b0e560fb00d..4b555d825373 100644 --- a/src/Init/Data/Int/Lemmas.lean +++ b/src/Init/Data/Int/Lemmas.lean @@ -531,4 +531,28 @@ theorem natCast_one : ((1 : Nat) : Int) = (1 : Int) := rfl @[simp] theorem natCast_mul (a b : Nat) : ((a * b : Nat) : Int) = (a : Int) * (b : Int) := by simp +/-! abs lemmas -/ + +@[simp] +theorem abs_eq_self {x : Int} (h : x ≥ 0) : x.abs = x := by + cases x + case ofNat h => + rfl + case negSucc h => + contradiction + +@[simp] +theorem Int.abs_zero : Int.abs 0 = 0 := rfl + +@[simp] +theorem abs_eq_neg {x : Int} (h : x < 0) : x.abs = -x := by + cases x + case ofNat h => + contradiction + case negSucc n => + rfl + +@[simp] +theorem ofNat_abs (x : Nat) : (x : Int).abs = (x : Int) := rfl + end Int