diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 3d0a304eef96..6d170e465671 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -1132,56 +1132,141 @@ theorem mulRec_succ_eq (l r : BitVec w) (s : Nat) : /-! #!/usr/bin/env python3 -# Check that 'hargonix-recurrences-statements' actually has the right statements. -# https://github.com/opencompl/lean4/pull/6 -# Theorems from: https://www21.in.tum.de/teaching/sar/SS20/7.pdf +##Check that 'hargonix-recurrences-statements' actually has the right statements. from z3 import * -# Define the `mulRec` function in Z3py -def mulRec(l : BitVecRef, r : BitVecRef, s : int): - # import pudb; pudb.set_trace() - assert isinstance(s, int) - assert isinstance(l, BitVecRef) - assert isinstance(r, BitVecRef) - cur = If(Extract(s, s, r) == 1, l << s, BitVecVal(0, w)) - if s == 0: return cur - else: return mulRec(l, r, s-1) + cur - -# Define BitVecs -w = 8 # Example width, you can adjust it as necessary -l = BitVec('l', w) -r = BitVec('r', w) - -mul_circuit = mulRec(l, r, w-1) -print(mul_circuit) - -# Define assertion -mul_circuit_correct = mul_circuit == l * r -s = Solver() -s.add(mul_circuit_correct) - -out = s.check() -print(out) +def mulExample(): + # Define the `mulRec` function in Z3py + def mulRec(l : BitVecRef, r : BitVecRef, s : int): + # import pudb; pudb.set_trace() + assert isinstance(s, int) + assert isinstance(l, BitVecRef) + assert isinstance(r, BitVecRef) + cur = If(Extract(s, s, r) == 1, l << s, BitVecVal(0, w)) + if s == 0: return cur + else: return mulRec(l, r, s-1) + cur + + # Define BitVecs + w = 8 # Example width, you can adjust it as necessary + l = BitVec('l', w) + r = BitVec('r', w) + + mul_circuit = mulRec(l, r, w-1) + print(mul_circuit) + + # Define assertion + mul_circuit_correct = mul_circuit == l * r + s = Solver() + s.add(ForAll(l, ForAll(r, mul_circuit_correct))) + + assert bool(s.check()) + + # verify what happens in mulRec for all 's' + for nbits_keep in range(1, w): + s = Solver() + s.add(ForAll(l, ForAll(r, mulRec(l, r, nbits_keep) == ZeroExt(w - nbits_keep - 1, Extract(nbits_keep, 0, l * r))))) + print(f"* checking mul eqn for width:'{nbits_keep}': '{s}'.") + assert bool(s.check()) +mulExample() -/ +@[simp] +theorem getLsb_ofBool (b : Bool) (i : Nat) : (BitVec.ofBool b).getLsb i = ((i = 0) && b) := by + rcases b with rfl | rfl + · simp [ofBool] + · simp [ofBool, getLsb_ofNat] + by_cases hi : (i = 0) + · simp [hi] + · simp [hi] + omega + +/-- zero extending a bitvector to width 1 equals the boolean of the lsb. -/ +theorem zeroExtend_one_eq_ofBool_getLsb_zero (x : BitVec w) : + x.zeroExtend 1 = BitVec.ofBool (x.getLsb 0) := by + ext i + simp [getLsb_zeroExtend, Fin.fin_one_eq_zero i] + +/-- `testBit 1 i` is true iff the index `i` equals 0. -/ +private theorem Nat.testBit_one_eq_true_iff_self_eq_zero {i : Nat} : + Nat.testBit 1 i = true ↔ i = 0 := by + cases i <;> simp + +/-- Zero extending `1#v` to `1#w` equals `1#w` when `v > 0`. -/ +theorem zeroExtend_ofNat_one_eq_ofNat_one_of_lt {v w : Nat} (hv : 0 < v): + (BitVec.ofNat v 1).zeroExtend w = BitVec.ofNat w 1 := by + ext i + obtain ⟨i, hilt⟩ := i + simp only [getLsb_zeroExtend, hilt, decide_True, getLsb_ofNat, Bool.true_and, + Bool.and_iff_right_iff_imp, decide_eq_true_eq] + intros hi1 + have hv := Nat.testBit_one_eq_true_iff_self_eq_zero.mp hi1 + omega + +@[simp] +theorem BitVec.mul_one {x : BitVec w} : x * (1#w) = x := by + apply eq_of_toNat_eq + simp [toNat_mul, Nat.mod_eq_of_lt x.isLt] + +@[simp] +theorem BitVec.mul_zero {x : BitVec w} : x * (0#w) = (0#w) := by + apply eq_of_toNat_eq + simp [toNat_mul] + +theorem BitVec.mul_add {x y z : BitVec w} : + x * (y + z) = x * y + x * z := by + apply eq_of_toNat_eq + simp + rw [Nat.mul_mod, Nat.mod_mod (y.toNat + z.toNat), + ← Nat.mul_mod, Nat.mul_add] + +/-- The Bitvector that is equal to `2^i % 2^w`, the power of 2 (`pot`). -/ +def pot {w : Nat} (i : Nat) : BitVec w := (1#w) <<< i + +theorem zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot (x : BitVec w) (i : Nat) : + zeroExtend w (x.truncate (i + 1)) = + zeroExtend w (x.truncate i) + (x &&& (BitVec.pot i)) := by + apply eq_of_toNat_eq + sorry + theorem mulRec_eq_mul_signExtend_truncate (l r : BitVec w) (s : Nat) : - (mulRec l r s) = l * r := by + (mulRec l r s) = l * ((r.truncate (s + 1)).zeroExtend w) := by induction w generalizing s case zero => apply Subsingleton.elim case succ w' hw => induction s case zero => - simp [mulRec, mulRec_zero_eq, signExtend, truncate] - sorry - case succ s' hs => sorry - --- Provable with sign extend theory. -@[simp] -theorem signExtend_eq_self (x : BitVec w) : x.signExtend w = x := sorry + simp [mulRec_zero_eq] + by_cases r.getLsb 0 + case pos hr => + simp only [hr, ↓reduceIte, truncate, zeroExtend_one_eq_ofBool_getLsb_zero, + hr, ofBool_true, ofNat_eq_ofNat] + rw [zeroExtend_ofNat_one_eq_ofNat_one_of_lt (by omega)]; simp + case neg hr => + simp [hr, zeroExtend_one_eq_ofBool_getLsb_zero] + case succ s' hs => + rw [mulRec_succ_eq] + rw [hs]; + have heq : + (if r.getLsb (s' + 1) = true then l <<< (s' + 1) else 0) = + (l * (r &&& (BitVec.pot (s' + 1)))) := by sorry + rw [heq, ← BitVec.mul_add] + rw [← zeroExtend_truncate_succ_eq_zeroExtend_truncate_add_pot] + +theorem zeroExtend_zeroExtend_of_lt (x : BitVec w) + (u v : Nat) (hi : i ≤ j) : + (x.zeroExtend i |>.zeroExtend j) = x.zeroExtend j := by + ext k + simp + intros hx; + have hk : k < j := by omega + sorry + -- omega theorem getLsb_mul (x y : BitVec w) (i : Nat) : (x * y).getLsb i = (mulRec x y w).getLsb i := by - rw [mulRec_eq_mul_signExtend_truncate] + simp [mulRec_eq_mul_signExtend_truncate] + rw [truncate] + sorry /-! ### le and lt -/