From 2f5ce555bd73819b208d24b2b67dce85bf4a2d28 Mon Sep 17 00:00:00 2001 From: Tobias Grosser Date: Sat, 19 Oct 2024 06:42:50 +0100 Subject: [PATCH] feat: add BitVec.toInt_sub This requires us to expand the theory of `Int.bmod`, add `Int.natCast_sub`, as well as `Int.ofNat_sub_ofNat` and a couple of related theorems. --- src/Init/Data/BitVec/Lemmas.lean | 10 ++++++++ src/Init/Data/Int/DivModLemmas.lean | 30 ++++++++++++++++++++++ src/Init/Data/Int/Lemmas.lean | 39 +++++++++++++++++++++++++++++ src/Init/Omega/Int.lean | 8 ++---- 4 files changed, 81 insertions(+), 6 deletions(-) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 8c7471551e0f..fa010ead49a0 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -316,6 +316,12 @@ theorem getLsbD_ofNat (n : Nat) (x : Nat) (i : Nat) : simp [Nat.sub_sub_eq_min, Nat.min_eq_right] omega +@[simp] theorem sub_add_bmod_cancel {x y : BitVec w} : + ((((2 ^ w : Nat) - y.toNat) : Int) + x.toNat).bmod (2 ^ w) = + ((x.toNat : Int) - y.toNat).bmod (2 ^ w) := by + rw [Int.sub_eq_add_neg, Int.add_assoc, Int.add_comm, Int.bmod_add_cancel, Int.add_comm, + Int.sub_eq_add_neg] + private theorem lt_two_pow_of_le {x m n : Nat} (lt : x < 2 ^ m) (le : m ≤ n) : x < 2 ^ n := Nat.lt_of_lt_of_le lt (Nat.pow_le_pow_of_le_right (by trivial : 0 < 2) le) @@ -1974,6 +1980,10 @@ theorem sub_def {n} (x y : BitVec n) : x - y = .ofNat n ((2^n - y.toNat) + x.toN @[simp] theorem toNat_sub {n} (x y : BitVec n) : (x - y).toNat = (((2^n - y.toNat) + x.toNat) % 2^n) := rfl +@[simp, bv_toNat] theorem toInt_sub (x y : BitVec w) : + (x - y).toInt = (x.toInt - y.toInt).bmod (2^w) := by + simp [toInt_eq_toNat_bmod, Int.natCast_sub (2 ^ w) y.toNat (by omega)] + -- We prefer this lemma to `toNat_sub` for the `bv_toNat` simp set. -- For reasons we don't yet understand, unfolding via `toNat_sub` sometimes -- results in `omega` generating proof terms that are very slow in the kernel. diff --git a/src/Init/Data/Int/DivModLemmas.lean b/src/Init/Data/Int/DivModLemmas.lean index 6750a468a553..018fd49d67b8 100644 --- a/src/Init/Data/Int/DivModLemmas.lean +++ b/src/Init/Data/Int/DivModLemmas.lean @@ -1125,6 +1125,17 @@ theorem emod_add_bmod_congr (x : Int) (n : Nat) : Int.bmod (x%n + y) n = Int.bmo simp [Int.emod_def, Int.sub_eq_add_neg] rw [←Int.mul_neg, Int.add_right_comm, Int.bmod_add_mul_cancel] +@[simp] +theorem emod_sub_bmod_congr (x : Int) (n : Nat) : Int.bmod (x%n - y) n = Int.bmod (x - y) n := by + simp [Int.emod_def, Int.sub_eq_add_neg] + rw [←Int.mul_neg, Int.add_right_comm, Int.bmod_add_mul_cancel] + +@[simp] +theorem sub_emod_bmod_congr (x : Int) (n : Nat) : Int.bmod (x - y%n) n = Int.bmod (x - y) n := by + simp [Int.emod_def] + rw [Int.sub_eq_add_neg, Int.neg_sub, Int.sub_eq_add_neg, ← Int.add_assoc, Int.add_right_comm, + Int.bmod_add_mul_cancel, Int.sub_eq_add_neg] + @[simp] theorem emod_mul_bmod_congr (x : Int) (n : Nat) : Int.bmod (x%n * y) n = Int.bmod (x * y) n := by simp [Int.emod_def, Int.sub_eq_add_neg] @@ -1140,9 +1151,28 @@ theorem bmod_add_bmod_congr : Int.bmod (Int.bmod x n + y) n = Int.bmod (x + y) n rw [Int.sub_eq_add_neg, Int.add_right_comm, ←Int.sub_eq_add_neg] simp +@[simp] +theorem bmod_sub_bmod_congr : Int.bmod (Int.bmod x n - y) n = Int.bmod (x - y) n := by + rw [Int.bmod_def x n] + split + next p => + simp only [emod_sub_bmod_congr] + next p => + rw [Int.sub_eq_add_neg, Int.sub_eq_add_neg, Int.add_right_comm, ←Int.sub_eq_add_neg, ← Int.sub_eq_add_neg] + simp [emod_sub_bmod_congr] + @[simp] theorem add_bmod_bmod : Int.bmod (x + Int.bmod y n) n = Int.bmod (x + y) n := by rw [Int.add_comm x, Int.bmod_add_bmod_congr, Int.add_comm y] +@[simp] theorem sub_bmod_bmod : Int.bmod (x - Int.bmod y n) n = Int.bmod (x - y) n := by + rw [Int.bmod_def y n] + split + next p => + simp [sub_emod_bmod_congr] + next p => + rw [Int.sub_eq_add_neg, Int.sub_eq_add_neg, Int.neg_add, Int.neg_neg, ← Int.add_assoc, ← Int.sub_eq_add_neg] + simp [sub_emod_bmod_congr] + @[simp] theorem bmod_mul_bmod : Int.bmod (Int.bmod x n * y) n = Int.bmod (x * y) n := by rw [bmod_def x n] diff --git a/src/Init/Data/Int/Lemmas.lean b/src/Init/Data/Int/Lemmas.lean index 4b0e560fb00d..ddc32205fadc 100644 --- a/src/Init/Data/Int/Lemmas.lean +++ b/src/Init/Data/Int/Lemmas.lean @@ -21,6 +21,15 @@ theorem subNatNat_of_sub_eq_zero {m n : Nat} (h : n - m = 0) : subNatNat m n = theorem subNatNat_of_sub_eq_succ {m n k : Nat} (h : n - m = succ k) : subNatNat m n = -[k+1] := by rw [subNatNat, h] +theorem subNatNat_of_sub {m n : Nat} (h : n ≤ m) : subNatNat m n = ↑(m - n) := by + rw [subNatNat, ofNat_eq_coe] + split + case h_1 _ _ => + simp + case h_2 _ h' => + rw [Nat.sub_eq_zero_of_le h] at h' + simp at h' + @[simp] protected theorem neg_zero : -(0:Int) = 0 := rfl @[norm_cast] theorem ofNat_add (n m : Nat) : (↑(n + m) : Int) = n + m := rfl @@ -54,8 +63,32 @@ theorem negOfNat_eq : negOfNat n = -ofNat n := rfl /- ## some basic functions and properties -/ +def eq_add_one_iff_neg_eq_negSucc {m n : Nat} : m = n + 1 ↔ -↑m = -[n+1] := by + simp [Neg.neg, instNegInt, Int.neg, negOfNat] + split <;> simp [Nat.add_one_inj] + @[norm_cast] theorem ofNat_inj : ((m : Nat) : Int) = (n : Nat) ↔ m = n := ⟨ofNat.inj, congrArg _⟩ +@[local simp] theorem ofNat_sub_ofNat (m n : Nat) (h : n ≤ m) : (↑m - ↑n : Int) = ↑(m - n) := by + simp only [HSub.hSub, instHSub, Sub.sub, instSub, Int.sub] + simp only [HAdd.hAdd, instHAdd, Add.add, instAdd, Int.add] + simp only [add_eq, ofNat_eq_coe, succ_eq_add_one, sub_eq] + split + case h_1 ia ib na nb ha hb => + simp only [ofNat_eq_coe] at ha hb + symm at hb + have := (@ofNat_inj nb 0).mp + cases n <;> simp_all + case h_2 ia ib na nb ha hb => + have h' := eq_add_one_iff_neg_eq_negSucc.mpr hb + have h'' := (@ofNat_inj m na).mp ha + rw [h'] at h + rw [h''] at h + rw [subNatNat_of_sub h] + congr <;> simp_all + · simp_all + · simp_all + theorem ofNat_eq_zero : ((n : Nat) : Int) = 0 ↔ n = 0 := ofNat_inj theorem ofNat_ne_zero : ((n : Nat) : Int) ≠ 0 ↔ n ≠ 0 := not_congr ofNat_eq_zero @@ -528,6 +561,12 @@ theorem natCast_one : ((1 : Nat) : Int) = (1 : Int) := rfl -- so it still makes sense to tag the lemmas with `@[simp]`. simp +@[simp] theorem natCast_sub (a b : Nat) (h : b ≤ a) : + ((a - b : Nat) : Int) = (a : Int) - (b : Int) := by + simp [h] + -- Note this only works because of local simp attributes in this file, + -- so it still makes sense to tag the lemmas with `@[simp]`. + @[simp] theorem natCast_mul (a b : Nat) : ((a * b : Nat) : Int) = (a : Int) * (b : Int) := by simp diff --git a/src/Init/Omega/Int.lean b/src/Init/Omega/Int.lean index d3d62c4cc78c..0d3ea8e9e89b 100644 --- a/src/Init/Omega/Int.lean +++ b/src/Init/Omega/Int.lean @@ -94,13 +94,9 @@ theorem ofNat_sub_dichotomy {a b : Nat} : b ≤ a ∧ ((a - b : Nat) : Int) = a - b ∨ a < b ∧ ((a - b : Nat) : Int) = 0 := by by_cases h : b ≤ a · left - have t := Int.ofNat_sub h - simp at t - exact ⟨h, t⟩ + simp [h] · right - have t := Nat.not_le.mp h - simp [Int.ofNat_sub_eq_zero h] - exact t + simp [Int.ofNat_sub_eq_zero, Nat.not_le.mp, h] theorem ofNat_congr {a b : Nat} (h : a = b) : (a : Int) = (b : Int) := congrArg _ h