(* Demo.v -- forked from mit-plv/fiat-crypto. *)
(* Following http://adam.chlipala.net/theses/andreser.pdf chapter 3 *)
From Coq Require Import ZArith Lia.
Require Import Crypto.Algebra.Nsatz.
Require Import Crypto.Util.Tactics.UniquePose Crypto.Util.Decidable.
Require Import Crypto.Util.Tuple Crypto.Util.Prod Crypto.Util.LetIn.
Require Import Crypto.Util.ListUtil Coq.Lists.List Crypto.Util.NatUtil.
Require Import QArith_base Qround Crypto.Util.QUtil.
Require Import Crypto.Algebra.Ring Crypto.Util.Decidable.Bool2Prop.
Import ListNotations. Local Open Scope Z_scope.
(* Multiplication and addition intended to survive into generated code.
   They are definitionally Z.mul and Z.add, but because they are separate
   constants a tactic such as cbv -[runtime_mul runtime_add] can leave them
   folded while all other arithmetic (for example on weights) is computed. *)
Definition runtime_mul : Z -> Z -> Z := Z.mul.
Definition runtime_add : Z -> Z -> Z := Z.add.
(* runtime_scope, opened locally with the %RT delimiter, makes the usual
   infix symbols denote the runtime operations defined above. *)
Declare Scope runtime_scope.
Delimit Scope runtime_scope with RT.
Notation "x * y" := (runtime_mul x y) : runtime_scope.
Notation "x + y" := (runtime_add x y) : runtime_scope.
(* Associational representation: a number is a list of pairs (weight, value)
   and denotes the sum of weight * value over the list. *)
Module Associational.
(* Value of an associational list: sum of fst t * snd t over its terms. *)
Definition eval (p:list (Z*Z)) : Z :=
fold_right Z.add 0%Z (map (fun t => fst t * snd t) p).
Lemma eval_nil : eval nil = 0.
Proof. trivial. Qed.
Lemma eval_cons p q : eval (p::q) = fst p * snd p + eval q.
Proof. trivial. Qed.
(* eval distributes over list concatenation. *)
Lemma eval_app p q: eval (p++q) = eval p + eval q.
Proof. induction p; rewrite <-?List.app_comm_cons;
rewrite ?eval_nil, ?eval_cons; nsatz. Qed.
(* Register the eval equations in the push_eval rewrite database. *)
#[global]
Hint Rewrite eval_nil eval_cons eval_app : push_eval.
(* [push]: autorewrite with the eval, list and pair simplification databases. *)
Local Ltac push := autorewrite with
push_eval push_map push_partition push_flat_map
push_fold_right push_nth_default cancel_pair.
(* Scaling every weight by a and every value by x scales eval by a * x. *)
Lemma eval_map_mul (a x:Z) (p:list (Z*Z))
: eval (List.map (fun t => (a*fst t, x*snd t)) p) = a*x*eval p.
Proof. induction p; push; nsatz. Qed.
#[global]
Hint Rewrite eval_map_mul : push_eval.
(* Product of two associational lists: one term for every pair of terms.
   Weights are multiplied with Z.mul and values with runtime_mul (the %RT
   scope), so cbv -[runtime_mul] (as in the example below) computes the
   weight products while keeping the value products symbolic. *)
Definition mul (p q:list (Z*Z)) : list (Z*Z) :=
flat_map (fun t =>
map (fun t' =>
(fst t * fst t', (snd t * snd t')%RT))
q) p.
(* Correctness of mul: eval is multiplicative. *)
Lemma eval_mul p q : eval (mul p q) = eval p * eval q.
Proof. induction p; cbv [mul]; push; nsatz. Qed.
#[global]
Hint Rewrite eval_mul : push_eval.
(* Worked example: synthesize a list ab whose eval is the product of the
   two-digit base-10 numbers 10*a1 + a0 and 10*b1 + b0, by rewriting the
   product into eval of mul and then computing everything but runtime_mul. *)
Example base10_2digit_mul (a0:Z) (a1:Z) (b0:Z) (b1:Z) :
{ab| eval ab = eval [(10,a1);(1,a0)] * eval [(10,b1);(1,b0)]}.
eexists ?[ab].
(* Goal: eval ?ab = eval [(10,a1);(1,a0)] * eval [(10,b1);(1,b0)] *)
rewrite <-eval_mul.
(* Goal: eval ?ab = eval (mul [(10,a1);(1,a0)] [(10,b1);(1,b0)]) *)
cbv -[runtime_mul eval].
(* Goal: eval ?ab = eval [(100,(a1*b1));(10,a1*b0);(10,a0*b1);(1,a0*b0)]%RT *)
trivial. Defined.
(* Split p by divisibility of weights by s.  The first component holds
   the terms whose weight is NOT a multiple of s (unchanged); the second
   holds the terms whose weight is a multiple of s, with the weight divided
   by s.  Note that partition returns the matching terms first, hence
   the swap below. *)
Definition split (s:Z) (p:list (Z*Z)) : list (Z*Z) * list (Z*Z)
:= let hi_lo := partition (fun t => fst t mod s =? 0) p in
(snd hi_lo, map (fun t => (fst t / s, snd t)) (fst hi_lo)).
(* Correctness of split: eval lo + s * eval hi = eval p, for s <> 0. *)
Lemma eval_split s p (s_nz:s<>0) :
eval (fst (split s p)) + s * eval (snd (split s p)) = eval p.
Proof. cbv [split]; induction p;
repeat match goal with
| |- context[?a/?b] =>
unique pose proof (Z_div_exact_full_2 a b ltac:(trivial) ltac:(trivial))
| _ => progress push
| _ => progress break_match
| _ => progress nsatz end. Qed.
(* The identity behind modular reduction: modulo s - c, a term s * b may
   be replaced by c * b. *)
Lemma reduction_rule a b s c (modulus_nz:s-c<>0) :
(a + s * b) mod (s - c) = (a + c * b) mod (s - c).
Proof. replace (a + s * b) with ((a + c*b) + b*(s-c)) by nsatz.
rewrite Z.add_mod,Z_mod_mult,Z.add_0_r,Z.mod_mod;trivial. Qed.
(* One reduction step modulo s - eval c: split p at s, multiply the high
   part by c and append it to the low part. *)
Definition reduce (s:Z) (c:list _) (p:list _) : list (Z*Z) :=
let lo_hi := split s p in fst lo_hi ++ mul c (snd lo_hi).
(* Correctness of reduce: eval is preserved modulo s - eval c. *)
Lemma eval_reduce s c p (s_nz:s<>0) (modulus_nz:s-eval c<>0) :
eval (reduce s c p) mod (s - eval c) = eval p mod (s - eval c).
Proof. cbv [reduce]; push.
rewrite <-reduction_rule, eval_split; trivial. Qed.
#[global]
Hint Rewrite eval_reduce : push_eval.
End Associational.
(* Positional representation: a tuple of n limbs in which limb i carries
   weight [weight i].  Its value is computed through the associational
   form. *)
Module Positional. Section Positional.
(* The weight function and the facts about it used by the lemmas below:
   the first weight is 1 and no weight is 0. *)
Context (weight : nat -> Z)
(weight_0 : weight 0%nat = 1)
(weight_nz : forall i, weight i <> 0).
(* Pair the limbs of xs, in Tuple.to_list order, with the weights
   weight 0, ..., weight (n-1). *)
Definition to_associational {n:nat} (xs:tuple Z n) : list (Z*Z)
:= combine (map weight (List.seq 0 n)) (Tuple.to_list n xs).
(* Value of a positional number: associational eval of its pairing. *)
Definition eval {n} x := Associational.eval (@to_associational n x).
Lemma eval_to_associational {n} x :
Associational.eval (@to_associational n x) = eval x.
Proof using Type. trivial. Qed.
(* Presentation note: the helpers zeros and add_to_nth below, together with
   their lemmas, are plumbing and can be skipped on a first read. *)
(* [push]: autorewrite with the eval, list, length and nat simplification
   databases. *)
Local Ltac push := autorewrite with push_eval push_map distr_length
push_flat_map push_fold_right push_nth_default cancel_pair natsimplify.
(* The all-zero tuple of length n.  The hole [_] is the proof obligation
   required by Tuple.from_list (presumably that the list has length n),
   discharged below by push. *)
Program Definition zeros n : tuple Z n
:= Tuple.from_list n (List.map (fun _ => 0) (List.seq 0 n)) _.
Next Obligation. push; reflexivity. Qed.
Lemma eval_zeros n : eval (zeros n) = 0.
Proof using weight_0.
cbv [eval Associational.eval to_associational zeros];
rewrite Tuple.to_list_from_list.
generalize dependent (List.seq 0 n); intro xs.
induction xs; simpl; nsatz. Qed.
(* Add x to limb i, via update_nth with runtime_add x, so the addition can
   be kept unevaluated by cbv -[runtime_add].  The length-preservation
   obligation is closed by length_update_nth. *)
Program Definition add_to_nth {n} i x : tuple Z n -> tuple Z n
:= Tuple.on_tuple (ListUtil.update_nth i (runtime_add x)) _.
Next Obligation. apply ListUtil.length_update_nth. Defined.
(* For an in-range index i < n, add_to_nth increases eval by weight i * x. *)
Lemma eval_add_to_nth {n} (i:nat) (H:(i<n)%nat) (x:Z) (xs:tuple Z n) :
eval (add_to_nth i x xs) = weight i * x + eval xs.
Proof using Type.
cbv [eval to_associational add_to_nth Tuple.on_tuple runtime_add].
rewrite !Tuple.to_list_from_list.
rewrite ListUtil.combine_update_nth_r at 1.
rewrite <-(update_nth_id i (List.combine _ _)) at 2.
rewrite <-!(ListUtil.splice_nth_equiv_update_nth_update _ _
(weight 0, 0)) by (push; lia); cbv [ListUtil.splice_nth id].
repeat match goal with
| _ => progress push
| _ => progress break_match
| _ => progress (apply Zminus_eq; ring_simplify)
| _ => rewrite <-ListUtil.map_nth_default_always
end; lia. Qed.
Hint Rewrite @eval_add_to_nth eval_zeros : push_eval.
(* Choose a limb for the associational term t, searching downward from
   index i.  It returns the first (largest) j <= i whose weight divides
   fst t, together with the value (fst t / weight j) * snd t.  If even index
   0 fails the test, it returns (0, fst t * snd t). *)
Fixpoint place (t:Z*Z) (i:nat) : nat * Z :=
if dec (fst t mod weight i = 0)
then (i, let c := fst t / weight i in (c * snd t)%RT)
else match i with S i' => place t i' | O => (O, fst t * snd t)%RT end.
(* The chosen index is at most the starting index. *)
Lemma place_in_range (t:Z*Z) (n:nat) : (fst (place t n) < S n)%nat.
Proof using Type. induction n; cbv [place] in *; break_match; autorewrite with cancel_pair; try lia. Qed.
(* Placing a term preserves its contribution weight * value. *)
Lemma weight_place t i : weight (fst (place t i)) * snd (place t i) = fst t * snd t.
Proof using weight_0 weight_nz. induction i; cbv [place] in *; break_match; push;
repeat match goal with |- context[?a/?b] =>
unique pose proof (Z_div_exact_full_2 a b ltac:(auto) ltac:(auto))
end; nsatz. Qed.
Hint Rewrite weight_place : push_eval.
(* Convert an associational list p to n limbs.  Starting from zeros n,
   each term t is placed at a limb index at most pred n (see place), and
   its rescaled value is added to that limb.  The let-bound pair is named
   tp so that it does not shadow the list parameter p. *)
Definition from_associational n (p:list (Z*Z)) :=
List.fold_right (fun t =>
let tp := place t (pred n) in
add_to_nth (fst tp) (snd tp) ) (zeros n) p.
(* For n <> 0, conversion to n limbs preserves the value. *)
Lemma eval_from_associational {n} p (n_nz:n<>O) :
eval (from_associational n p) = Associational.eval p.
Proof using weight_0 weight_nz. induction p; cbv [from_associational] in *; push; try
pose proof place_in_range a (pred n); destruct n; cbn [pred] in *; try lia; try nsatz. Qed.
Hint Rewrite @eval_from_associational : push_eval.
(* Modular multiplication for a nonzero modulus m written as
   s - Associational.eval c. *)
Section mulmod.
Context (m:Z) (m_nz:m <> 0) (s:Z) (s_nz:s <> 0)
(c:list (Z*Z)) (Hm:m = s - Associational.eval c).
(* Convert both operands to associational form, multiply, apply one
   Associational.reduce step with s and c, and convert back to n limbs. *)
Definition mulmod {n} (a b:tuple Z n) : tuple Z n
:= let a_a := to_associational a in
let b_a := to_associational b in
let ab_a := Associational.mul a_a b_a in
let abm_a := Associational.reduce s c ab_a in
from_associational n abm_a.
(* Correctness of mulmod modulo m, for a nonzero number of limbs. *)
Lemma eval_mulmod {n} (H:(n<>0)%nat) (f g:tuple Z n) :
eval (mulmod f g) mod m = (eval f * eval g) mod m.
Proof using Hm m_nz s_nz weight_0 weight_nz. cbv [mulmod]; rewrite Hm in *; push; trivial. Qed.
End mulmod.
End Positional. End Positional.
Import Associational Positional.
(* Coercions nat to Z to Q, so that a limb index i : nat can appear directly
   in the rational expression defining w below. *)
Local Coercion Z.of_nat : nat >-> Z.
Local Coercion QArith_base.inject_Z : Z >-> Q.
(* Weight of limb i: w i = 2 ^ ceil(25.5 * i), with 25 + 1/2 read as a
   rational since it is part of the argument of Qceiling.  The exponents
   0, 26, 51, 77, 102, ... differ alternately by 26 and 25, which gives the
   mixed-radix base 2^25.5 used for 2^255 - 19 in the example below. *)
Definition w (i:nat) : Z := 2^Qceiling((25+1/2)*i).
(* Synthesize a 10-limb multiplication modulo 2^255 - 19 in base 2^25.5:
   find a tuple fg whose value under weights w is congruent to the product
   of the values of f and g. *)
Example base_25_5_mul (f g:tuple Z 10) :
{ fg : tuple Z 10 | (eval w fg) mod (2^255-19)
= (eval w f * eval w g) mod (2^255-19) }.
(* Manually assign names to the limbs so the final goal pretty-prints
   readably.  Judging from the result shown below, f0 and g0 are the limbs
   with weight w 0 = 1. *)
destruct f as [[[[[[[[[f9 f8]f7]f6]f5]f4]f3]f2]f1]f0].
destruct g as [[[[[[[[[g9 g8]g7]g6]g5]g4]g3]g2]g1]g0].
eexists ?[fg].
(* Rewrite the right-hand side as eval of mulmod with s = 2^255 and
   c = [(1,19)], so that m = 2^255 - 19.  The side conditions of
   eval_mulmod are closed by computation (vm_decide), with
   pow_ceil_mul_nat_nonzero presumably handling the weight-nonzero goal. *)
erewrite <-eval_mulmod with (s:=2^255) (c:=[(1,19)])
by (try eapply pow_ceil_mul_nat_nonzero; vm_decide).
(* eval w ?fg mod (2 ^ 255 - 19) = *)
(* eval w *)
(* (mulmod w (2^255) [(1, 19)] (f9,f8,f7,f6,f5,f4,f3,f2,f1,f0) *)
(* (g9,g8,g7,g6,g5,g4,g3,g2,g1,g0)) mod (2^255 - 19) *)
(* Peel off [mod] and [eval w] on both sides, leaving an equation between
   tuples that instantiates ?fg. *)
eapply f_equal2; [|trivial]. eapply f_equal.
(* ?fg = *)
(* mulmod w (2 ^ 255) [(1, 19)] (f9, f8, f7, f6, f5, f4, f3, f2, f1, f0) *)
(* (g9, g8, g7, g6, g5, g4, g3, g2, g1, g0) *)
(* Compute everything except the runtime operations, then unfold those
   too so that ring_simplify_subterms can normalize the limb expressions. *)
cbv -[runtime_mul runtime_add]; cbv [runtime_mul runtime_add].
ring_simplify_subterms.
(* In the result below, the factor 19 comes from 2^255 = 19 modulo
   2^255 - 19.  The factor 2 appears on products of two odd-indexed limbs,
   whose weights carry an extra bit because of the ceiling in w, and
   38 = 2 * 19 combines both effects. *)
(* ?fg =
(f0*g9+ f1*g8+ f2*g7+ f3*g6+ f4*g5+ f5*g4+ f6*g3+ f7*g2+ f8*g1+ f9*g0,
f0*g8+ 2*f1*g7+ f2*g6+ 2*f3*g5+ f4*g4+ 2*f5*g3+ f6*g2+ 2*f7*g1+ f8*g0+ 38*f9*g9,
f0*g7+ f1*g6+ f2*g5+ f3*g4+ f4*g3+ f5*g2+ f6*g1+ f7*g0+ 19*f8*g9+ 19*f9*g8,
f0*g6+ 2*f1*g5+ f2*g4+ 2*f3*g3+ f4*g2+ 2*f5*g1+ f6*g0+ 38*f7*g9+ 19*f8*g8+ 38*f9*g7,
f0*g5+ f1*g4+ f2*g3+ f3*g2+ f4*g1+ f5*g0+ 19*f6*g9+ 19*f7*g8+ 19*f8*g7+ 19*f9*g6,
f0*g4+ 2*f1*g3+ f2*g2+ 2*f3*g1+ f4*g0+ 38*f5*g9+ 19*f6*g8+ 38*f7*g7+ 19*f8*g6+ 38*f9*g5,
f0*g3+ f1*g2+ f2*g1+ f3*g0+ 19*f4*g9+ 19*f5*g8+ 19*f6*g7+ 19*f7*g6+ 19*f8*g5+ 19*f9*g4,
f0*g2+ 2*f1*g1+ f2*g0+ 38*f3*g9+ 19*f4*g8+ 38*f5*g7+ 19*f6*g6+ 38*f7*g5+ 19*f8*g4+ 38*f9*g3,
f0*g1+ f1*g0+ 19*f2*g9+ 19*f3*g8+ 19*f4*g7+ 19*f5*g6+ 19*f6*g5+ 19*f7*g4+ 19*f8*g3+ 19*f9*g2,
f0*g0+ 38*f1*g9+ 19*f2*g8+ 38*f3*g7+ 19*f4*g6+ 38*f5*g5+ 19*f6*g4+ 38*f7*g3+ 19*f8*g2+ 38*f9*g1) *)
trivial.
Defined.
(* Running [Eval cbv] on this example would produce an unwieldy term because of the [destruct]s used to name the limbs. *)