From 4a7b7ae17045365d7ae6d670309472ddea809383 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Mon, 14 Oct 2024 21:33:46 -0500 Subject: [PATCH] feat: toInt_abs We implement `toInt_abs`. A subtle wrinkle is to note that `abs (intMin w) = intMin w`, which complicates our proof. --- src/Init/Data/BitVec/Lemmas.lean | 139 ++++++++++++++++++++++++++++--- src/Init/Data/Int/Basic.lean | 7 ++ src/Init/Data/Int/Lemmas.lean | 24 ++++++ 3 files changed, 159 insertions(+), 11 deletions(-) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index ba6260e0d488..b01bcb93656f 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -206,6 +206,7 @@ theorem eq_of_getMsbD_eq {x y : BitVec w} theorem of_length_zero {x : BitVec 0} : x = 0#0 := by ext; simp theorem toNat_zero_length (x : BitVec 0) : x.toNat = 0 := by simp [of_length_zero] +theorem toInt_zero_length (x : BitVec 0) : x.toInt = 0 := by simp [of_length_zero] theorem getLsbD_zero_length (x : BitVec 0) : x.getLsbD i = false := by simp theorem getMsbD_zero_length (x : BitVec 0) : x.getMsbD i = false := by simp theorem msb_zero_length (x : BitVec 0) : x.msb = false := by simp [BitVec.msb, of_length_zero] @@ -353,7 +354,19 @@ theorem msb_eq_getLsbD_last (x : BitVec w) : @[bv_toNat] theorem getLsbD_succ_last (x : BitVec (w + 1)) : x.getLsbD w = decide (2 ^ w ≤ x.toNat) := getLsbD_last x -@[bv_toNat] theorem msb_eq_decide (x : BitVec w) : BitVec.msb x = decide (2 ^ (w-1) ≤ x.toNat) := by + +/-- +An alternative to `msb_eq_decide` in terms of `2 * x.toNat`, +in order to avoid `2 ^ (w - 1)`. +-/ +@[bv_toNat] theorem msb_eq_decide_le_mul_two (x : BitVec w) : + BitVec.msb x = decide (2 ^ w ≤ 2 * x.toNat) := by + rw [x.msb_eq_getLsbD_last, x.getLsbD_last] + simp + rcases w with rfl | w <;> simp <;> omega + +@[bv_toNat, deprecated msb_eq_decide_le_mul_two (since := "21-10-2024") ] +theorem msb_eq_decide (x : BitVec w) : BitVec.msb x = decide (2 ^ (w-1) ≤ x.toNat) := by simp [msb_eq_getLsbD_last, getLsbD_last] theorem toNat_ge_of_msb_true {x : BitVec n} (p : BitVec.msb x = true) : x.toNat ≥ 2^(n-1) := by @@ -463,6 +476,38 @@ theorem toInt_pos_iff {w : Nat} {x : BitVec w} : 0 ≤ BitVec.toInt x ↔ 2 * x.toNat < 2 ^ w := by simp [toInt_eq_toNat_cond]; omega +/- +If `x.msb` is false, then the value of `x` when interpreted as a 2s complement +integer is between `[0..2^n/2)`. +To avoid the corner case at `n = 0`, we phrase the bounds as `2 * x < 2^n` instead of `x < 2^(n-1)`. +-/ +theorem toInt_bounds_of_msb_eq_false {x : BitVec n} (hmsb : x.msb = false) : + 0 ≤ x.toInt ∧ 2 * x.toInt < 2^n := by + have := x.msb_eq_decide_le_mul_two + rw [hmsb] at this + simp only [false_eq_decide_iff, Nat.not_le] at this + rw [BitVec.toInt_eq_toNat_cond] + simp [this] + apply And.intro + · omega + · norm_cast + +/- +If `x.msb` is true, then the value of `x` when interpreted as a 2s complement +integer is between `[-2^n..0). +-/ +theorem toInt_bounds_of_msb_eq_true {x : BitVec n} (hmsb : x.msb = true) : + -2^n ≤ x.toInt ∧ x.toInt < 0 := by + have := x.msb_eq_decide_le_mul_two + rw [hmsb] at this + simp only [true_eq_decide_iff] at this + rw [BitVec.toInt_eq_toNat_cond] + simp [show ¬ 2 * x.toNat < 2 ^ n by omega] + apply And.intro + · norm_cast + omega + · omega + theorem eq_zero_or_eq_one (a : BitVec 1) : a = 0#1 ∨ a = 1#1 := by obtain ⟨a, ha⟩ := a simp only [Nat.reducePow] @@ -2070,16 +2115,6 @@ theorem smod_zero {x : BitVec n} : x.smod 0#n = x := by · simp · by_cases h : x = 0#n <;> simp [h] -/-! ### abs -/ - -@[simp, bv_toNat] -theorem toNat_abs {x : BitVec w} : x.abs.toNat = if x.msb then 2^w - x.toNat else x.toNat := by - simp only [BitVec.abs, neg_eq] - by_cases h : x.msb = true - · simp only [h, ↓reduceIte, toNat_neg] - have : 2 * x.toNat ≥ 2 ^ w := BitVec.msb_eq_true_iff_two_mul_ge.mp h - rw [Nat.mod_eq_of_lt (by omega)] - · simp [h] /-! ### mul -/ @@ -2643,6 +2678,52 @@ theorem toInt_neg_of_ne_intMin {x : BitVec w} (rs : x ≠ intMin w) : have := @Nat.two_pow_pred_mul_two w (by omega) split <;> split <;> omega +/-- The msb of `intMin w` is `true` for all `w > 0` -/ +@[simp] theorem msb_intMin : (intMin w).msb = decide (w > 0) := by + rw [intMin] + rw [msb_eq_decide] + simp + rcases w with rfl | w + · rfl + · simp + have : 0 < 2^w := Nat.pow_pos (by decide) + have : 2^w < 2^(w + 1) := by + rw [Nat.pow_succ] + omega + rw [Nat.mod_eq_of_lt (by omega)] + simp + +/-- +If the width is zero, then `intMin` is `0`, +and otherwise it is `-2^(n - 1)`. +-/ +theorem toInt_intMin_eq_if (n : Nat) : (BitVec.intMin n).toInt = + if n = 0 then 0 else - 2^(n - 1) := by + simp [BitVec.toInt_intMin] + rcases n with rfl | n + · simp + · simp + norm_cast + have : 2^n > 0 := by exact Nat.two_pow_pos n + have : 2^n < 2^(n + 1) := by + simp [Nat.pow_succ] + omega + rw [Nat.mod_eq_of_lt (by omega)] + +/-- +Negating `intMin` returns `intMin`. +Thus, converting `(-x)` to an `Int` return `- x.toInt` for all bitvectors other than `intMin`. +-/ +theorem toInt_neg_eq_if {x : BitVec n} : + (-x).toInt = + if x = intMin n + then x.toInt + else - x.toInt := by + by_cases hx : x = intMin n + · simp [hx] + · simp [hx] + rw [toInt_neg_of_ne_intMin hx] + /-! ### intMax -/ /-- The bitvector of width `w` that has the largest value when interpreted as an integer. -/ @@ -2674,6 +2755,42 @@ theorem getLsbD_intMax (w : Nat) : (intMax w).getLsbD i = decide (i + 1 < w) := · rw [Nat.sub_add_cancel (Nat.two_pow_pos (w - 1)), Nat.two_pow_pred_mod_two_pow (by omega)] +/-! ### abs -/ + +theorem abs_def {x : BitVec w} : x.abs = if x.msb then .neg x else x := rfl + +theorem abs_eq_if (x : BitVec w) : x.abs = + if x.msb = true then + if x = BitVec.intMin w then (BitVec.intMin w) else -x + else x := by + · rw [BitVec.abs_def] + by_cases hx : x.msb = true <;> by_cases hx' : x = BitVec.intMin w <;> simp [hx, hx'] + +theorem toInt_abs (x : BitVec w) : + x.abs.toInt = if x = (intMin w) then if w = 0 then 0 else - 2^(w - 1) else x.toInt.abs := by + rcases w with rfl | w + · simp [toInt_zero_length] + · simp only [gt_iff_lt, Nat.zero_lt_succ, Nat.add_one_ne_zero, ↓reduceIte] + rw [BitVec.abs_eq_if] + by_cases hx : x = intMin (w + 1) + · simp only [hx, reduceIte] + have := BitVec.msb_intMin (w := w + 1) + rw [this] + simp only [gt_iff_lt, Nat.zero_lt_succ, decide_True, ↓reduceIte] + rw [BitVec.toInt_intMin_eq_if] + simp + · simp only [hx, reduceIte] + rcases hmsb : x.msb + · simp only [Bool.false_eq_true, ↓reduceIte] + have := BitVec.toInt_bounds_of_msb_eq_false hmsb + rw [Int.abs_eq_self] + omega + · simp only [reduceIte] + have hxbounds := BitVec.toInt_bounds_of_msb_eq_true hmsb + rw [BitVec.toInt_neg_eq_if] + simp only [hx, reduceIte] + rw [Int.abs_eq_neg (by omega)] + /-! ### Non-overflow theorems -/ /-- If `x.toNat * y.toNat < 2^w`, then the multiplication `(x * y)` does not overflow. -/ diff --git a/src/Init/Data/Int/Basic.lean b/src/Init/Data/Int/Basic.lean index dbf661c4be1b..fed9719af55b 100644 --- a/src/Init/Data/Int/Basic.lean +++ b/src/Init/Data/Int/Basic.lean @@ -333,6 +333,13 @@ instance : Min Int := minOfLe instance : Max Int := maxOfLe +/-- +Return the absolute value of an integer. +-/ +def abs : Int → Int + | ofNat n => .ofNat n + | negSucc n => .ofNat n.succ + end Int /-- diff --git a/src/Init/Data/Int/Lemmas.lean b/src/Init/Data/Int/Lemmas.lean index 4b0e560fb00d..4b555d825373 100644 --- a/src/Init/Data/Int/Lemmas.lean +++ b/src/Init/Data/Int/Lemmas.lean @@ -531,4 +531,28 @@ theorem natCast_one : ((1 : Nat) : Int) = (1 : Int) := rfl @[simp] theorem natCast_mul (a b : Nat) : ((a * b : Nat) : Int) = (a : Int) * (b : Int) := by simp +/-! abs lemmas -/ + +@[simp] +theorem abs_eq_self {x : Int} (h : x ≥ 0) : x.abs = x := by + cases x + case ofNat h => + rfl + case negSucc h => + contradiction + +@[simp] +theorem Int.abs_zero : Int.abs 0 = 0 := rfl + +@[simp] +theorem abs_eq_neg {x : Int} (h : x < 0) : x.abs = -x := by + cases x + case ofNat h => + contradiction + case negSucc n => + rfl + +@[simp] +theorem ofNat_abs (x : Nat) : (x : Int).abs = (x : Int) := rfl + end Int