forked from LaRiffle/ariann
-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocess.py
253 lines (238 loc) · 55.8 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import time
import torch as th
import syft as sy
# fmt: off
config_zoo = {
"template": {
'fss_eq': [],
'fss_comp': [],
'mul': [],
'matmul': [],
'conv2d': []
},
"network1-mnist-1": {
'matmul': [((1, 784), (784, 128)), ((1, 128), (128, 128)), ((1, 128), (128, 10))] ,
'fss_comp': [128, 128, 10, 90, 10] ,
'mul': [((1, 128), (1, 128)), ((1, 128), (1, 128)), ((1, 10), (1, 10))] ,
'fss_eq': [1] ,
},
"network1-mnist-128": {
'fss_eq': [128],
'fss_comp': [16384, 16384, 1280, 11520, 1280],
'mul': [((128, 128), (128, 128)), ((128, 128), (128, 128)), ((128, 10), (128, 10))],
'matmul': [((128, 784), (784, 128)), ((128, 128), (128, 128)), ((128, 128), (128, 10))],
},
"network2-mnist-128": {
'conv2d': [((128, 1, 28, 28), (16, 1, 5, 5), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (0, 0), (1, 1), 1))), ((128, 16, 12, 12), (16, 16, 5, 5), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (0, 0), (1, 1), 1)))] ,
'fss_comp': [589824, 294912, 294912, 65536, 32768, 32768, 12800, 1280, 11520, 1280] ,
'mul': [((128, 16, 144, 2), (128, 16, 144, 2)), ((128, 16, 144, 1), (128, 16, 144, 1)), ((128, 16, 12, 12), (128, 16, 12, 12)), ((128, 16, 16, 2), (128, 16, 16, 2)), ((128, 16, 16, 1), (128, 16, 16, 1)), ((128, 16, 4, 4), (128, 16, 4, 4)), ((128, 100), (128, 100)), ((128, 10), (128, 10))] ,
'matmul': [((128, 256), (256, 100)), ((128, 100), (100, 10))] ,
'fss_eq': [128] ,
},
"lenet-mnist-128": {
'conv2d': [((128, 1, 28, 28), (20, 1, 5, 5), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (0, 0), (1, 1), 1))), ((128, 20, 12, 12), (50, 20, 5, 5), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (0, 0), (1, 1), 1)))] ,
'fss_comp': [737280, 368640, 368640, 204800, 102400, 102400, 64000, 1280, 11520, 1280] ,
'mul': [((128, 20, 144, 2), (128, 20, 144, 2)), ((128, 20, 144, 1), (128, 20, 144, 1)), ((128, 20, 12, 12), (128, 20, 12, 12)), ((128, 50, 16, 2), (128, 50, 16, 2)), ((128, 50, 16, 1), (128, 50, 16, 1)), ((128, 50, 4, 4), (128, 50, 4, 4)), ((128, 500), (128, 500)), ((128, 10), (128, 10))] ,
'matmul': [((128, 800), (800, 500)), ((128, 500), (500, 10))] ,
'fss_eq': [128] ,
},
"alexnet-cifar10-128": {
'conv2d': [((128, 3, 32, 32), (96, 3, 11, 11), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 4, (10, 10), (1, 1), 1))), ((128, 96, 5, 5), (256, 96, 5, 5), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((128, 256, 1, 1), (384, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((128, 384, 1, 1), (384, 384, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((128, 384, 1, 1), (256, 384, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1)))] ,
'fss_comp': [1228800, 614400, 307200, 307200, 307200, 131072, 65536, 32768, 32768, 32768, 49152, 49152, 32768, 32768, 32768, 1280, 11520, 1280] ,
'mul': [((128, 96, 25, 4), (128, 96, 25, 4)), ((128, 96, 25, 2), (128, 96, 25, 2)), ((128, 96, 25, 1), (128, 96, 25, 1)), ((128, 96, 25, 1), (128, 96, 25, 1)), ((128, 96, 5, 5), (128, 96, 5, 5)), ((96,), (3200, 96)), ((3200, 96), (96,)), ((128, 256, 1, 4), (128, 256, 1, 4)), ((128, 256, 1, 2), (128, 256, 1, 2)), ((128, 256, 1, 1), (128, 256, 1, 1)), ((128, 256, 1, 1), (128, 256, 1, 1)), ((128, 256, 1, 1), (128, 256, 1, 1)), ((256,), (128, 256)), ((128, 256), (256,)), ((128, 384, 1, 1), (128, 384, 1, 1)), ((128, 384, 1, 1), (128, 384, 1, 1)), ((128, 256, 1, 1), (128, 256, 1, 1)), ((128, 256), (128, 256)), ((128, 256), (128, 256)), ((128, 10), (128, 10))] ,
'matmul': [((128, 256), (256, 256)), ((128, 256), (256, 256)), ((128, 256), (256, 10))] ,
'fss_eq': [128] ,
},
"alexnet-tiny-imagenet-8": {
'conv2d': [((8, 3, 64, 64), (64, 3, 11, 11), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 4, (2, 2), (1, 1), 1))), ((8, 64, 7, 7), (192, 64, 5, 5), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (2, 2), (1, 1), 1))), ((8, 192, 3, 3), (384, 192, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 384, 3, 3), (256, 384, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 256, 3, 3), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1)))] ,
'fss_comp': [100352, 50176, 25088, 25088, 25088, 55296, 27648, 13824, 13824, 13824, 27648, 18432, 8192, 4096, 2048, 2048, 2048, 8192, 8192, 318400, 1600] ,
'mul': [((8, 64, 49, 4), (8, 64, 49, 4)), ((8, 64, 49, 2), (8, 64, 49, 2)), ((8, 64, 49, 1), (8, 64, 49, 1)), ((8, 64, 49, 1), (8, 64, 49, 1)), ((8, 64, 7, 7), (8, 64, 7, 7)), ((8, 192, 9, 4), (8, 192, 9, 4)), ((8, 192, 9, 2), (8, 192, 9, 2)), ((8, 192, 9, 1), (8, 192, 9, 1)), ((8, 192, 9, 1), (8, 192, 9, 1)), ((8, 192, 3, 3), (8, 192, 3, 3)), ((8, 384, 3, 3), (8, 384, 3, 3)), ((8, 256, 3, 3), (8, 256, 3, 3)), ((8, 256, 1, 4), (8, 256, 1, 4)), ((8, 256, 1, 2), (8, 256, 1, 2)), ((8, 256, 1, 1), (8, 256, 1, 1)), ((8, 256, 1, 1), (8, 256, 1, 1)), ((8, 256, 1, 1), (8, 256, 1, 1)), ((8, 1024), (8, 1024)), ((8, 1024), (8, 1024))] ,
'matmul': [((8, 256), (256, 1024)), ((8, 1024), (1024, 1024)), ((8, 1024), (1024, 200))] ,
'fss_eq': [8] ,
},
"alexnet-tiny-imagenet-16": {
'conv2d': [((16, 3, 64, 64), (64, 3, 11, 11), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 4, (2, 2), (1, 1), 1))), ((16, 64, 7, 7), (192, 64, 5, 5), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (2, 2), (1, 1), 1))), ((16, 192, 3, 3), (384, 192, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 384, 3, 3), (256, 384, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 256, 3, 3), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1)))] ,
'fss_comp': [200704, 100352, 50176, 50176, 50176, 110592, 55296, 27648, 27648, 27648, 55296, 36864, 16384, 8192, 4096, 4096, 4096, 16384, 16384, 636800, 3200] ,
'mul': [((16, 64, 49, 4), (16, 64, 49, 4)), ((16, 64, 49, 2), (16, 64, 49, 2)), ((16, 64, 49, 1), (16, 64, 49, 1)), ((16, 64, 49, 1), (16, 64, 49, 1)), ((16, 64, 7, 7), (16, 64, 7, 7)), ((16, 192, 9, 4), (16, 192, 9, 4)), ((16, 192, 9, 2), (16, 192, 9, 2)), ((16, 192, 9, 1), (16, 192, 9, 1)), ((16, 192, 9, 1), (16, 192, 9, 1)), ((16, 192, 3, 3), (16, 192, 3, 3)), ((16, 384, 3, 3), (16, 384, 3, 3)), ((16, 256, 3, 3), (16, 256, 3, 3)), ((16, 256, 1, 4), (16, 256, 1, 4)), ((16, 256, 1, 2), (16, 256, 1, 2)), ((16, 256, 1, 1), (16, 256, 1, 1)), ((16, 256, 1, 1), (16, 256, 1, 1)), ((16, 256, 1, 1), (16, 256, 1, 1)), ((16, 1024), (16, 1024)), ((16, 1024), (16, 1024))] ,
'matmul': [((16, 256), (256, 1024)), ((16, 1024), (1024, 1024)), ((16, 1024), (1024, 200))] ,
'fss_eq': [16] ,
},
"alexnet-tiny-imagenet-128": {
'conv2d': [((128, 3, 64, 64), (64, 3, 11, 11), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 4, (2, 2), (1, 1), 1))), ((128, 64, 7, 7), (192, 64, 5, 5), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (2, 2), (1, 1), 1))), ((128, 192, 3, 3), (384, 192, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((128, 384, 3, 3), (256, 384, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((128, 256, 3, 3), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1)))] ,
'fss_comp': [1605632, 802816, 401408, 401408, 401408, 884736, 442368, 221184, 221184, 221184, 442368, 294912, 131072, 65536, 32768, 32768, 32768, 131072, 131072, 5094400, 25600] ,
'mul': [((128, 64, 49, 4), (128, 64, 49, 4)), ((128, 64, 49, 2), (128, 64, 49, 2)), ((128, 64, 49, 1), (128, 64, 49, 1)), ((128, 64, 49, 1), (128, 64, 49, 1)), ((128, 64, 7, 7), (128, 64, 7, 7)), ((128, 192, 9, 4), (128, 192, 9, 4)), ((128, 192, 9, 2), (128, 192, 9, 2)), ((128, 192, 9, 1), (128, 192, 9, 1)), ((128, 192, 9, 1), (128, 192, 9, 1)), ((128, 192, 3, 3), (128, 192, 3, 3)), ((128, 384, 3, 3), (128, 384, 3, 3)), ((128, 256, 3, 3), (128, 256, 3, 3)), ((128, 256, 1, 4), (128, 256, 1, 4)), ((128, 256, 1, 2), (128, 256, 1, 2)), ((128, 256, 1, 1), (128, 256, 1, 1)), ((128, 256, 1, 1), (128, 256, 1, 1)), ((128, 256, 1, 1), (128, 256, 1, 1)), ((128, 1024), (128, 1024)), ((128, 1024), (128, 1024))] ,
'matmul': [((128, 256), (256, 1024)), ((128, 1024), (1024, 1024)), ((128, 1024), (1024, 200))] ,
'fss_eq': [128] ,
},
"vgg16-cifar10-1": {
'conv2d': [((1, 3, 32, 32), (64, 3, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 64, 32, 32), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 64, 16, 16), (128, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 128, 16, 16), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 128, 8, 8), (256, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 256, 8, 8), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 256, 8, 8), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 256, 4, 4), (512, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 512, 2, 2), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 512, 2, 2), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 512, 2, 2), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1)))] ,
'fss_comp': [65536, 32768, 16384, 16384, 32768, 16384, 8192, 8192, 16384, 16384, 8192, 4096, 4096, 8192, 8192, 4096, 2048, 2048, 2048, 2048, 1024, 512, 512, 4096, 4096, 90, 10] ,
'mul': [((1, 64, 32, 32), (1, 64, 32, 32)), ((1, 64, 256, 2), (1, 64, 256, 2)), ((1, 64, 256, 1), (1, 64, 256, 1)), ((1, 64, 16, 16), (1, 64, 16, 16)), ((1, 128, 16, 16), (1, 128, 16, 16)), ((1, 128, 64, 2), (1, 128, 64, 2)), ((1, 128, 64, 1), (1, 128, 64, 1)), ((1, 128, 8, 8), (1, 128, 8, 8)), ((1, 256, 8, 8), (1, 256, 8, 8)), ((1, 256, 8, 8), (1, 256, 8, 8)), ((1, 256, 16, 2), (1, 256, 16, 2)), ((1, 256, 16, 1), (1, 256, 16, 1)), ((1, 256, 4, 4), (1, 256, 4, 4)), ((1, 512, 4, 4), (1, 512, 4, 4)), ((1, 512, 4, 4), (1, 512, 4, 4)), ((1, 512, 4, 2), (1, 512, 4, 2)), ((1, 512, 4, 1), (1, 512, 4, 1)), ((1, 512, 2, 2), (1, 512, 2, 2)), ((1, 512, 2, 2), (1, 512, 2, 2)), ((1, 512, 2, 2), (1, 512, 2, 2)), ((1, 512, 1, 2), (1, 512, 1, 2)), ((1, 512, 1, 1), (1, 512, 1, 1)), ((1, 512, 1, 1), (1, 512, 1, 1)), ((1, 4096), (1, 4096)), ((1, 4096), (1, 4096))] ,
'matmul': [((1, 512), (512, 4096)), ((1, 4096), (4096, 4096)), ((1, 4096), (4096, 10))] ,
'fss_eq': [1] ,
},
"vgg16-cifar10-8": {
'conv2d': [((8, 3, 32, 32), (64, 3, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 64, 32, 32), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 64, 16, 16), (128, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 128, 16, 16), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 128, 8, 8), (256, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 256, 8, 8), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 256, 8, 8), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 256, 4, 4), (512, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 512, 2, 2), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 512, 2, 2), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 512, 2, 2), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1)))] ,
'fss_comp': [524288, 262144, 131072, 131072, 262144, 131072, 65536, 65536, 131072, 131072, 65536, 32768, 32768, 65536, 65536, 32768, 16384, 16384, 16384, 16384, 8192, 4096, 4096, 32768, 32768, 720, 80] ,
'mul': [((8, 64, 32, 32), (8, 64, 32, 32)), ((8, 64, 256, 2), (8, 64, 256, 2)), ((8, 64, 256, 1), (8, 64, 256, 1)), ((8, 64, 16, 16), (8, 64, 16, 16)), ((8, 128, 16, 16), (8, 128, 16, 16)), ((8, 128, 64, 2), (8, 128, 64, 2)), ((8, 128, 64, 1), (8, 128, 64, 1)), ((8, 128, 8, 8), (8, 128, 8, 8)), ((8, 256, 8, 8), (8, 256, 8, 8)), ((8, 256, 8, 8), (8, 256, 8, 8)), ((8, 256, 16, 2), (8, 256, 16, 2)), ((8, 256, 16, 1), (8, 256, 16, 1)), ((8, 256, 4, 4), (8, 256, 4, 4)), ((8, 512, 4, 4), (8, 512, 4, 4)), ((8, 512, 4, 4), (8, 512, 4, 4)), ((8, 512, 4, 2), (8, 512, 4, 2)), ((8, 512, 4, 1), (8, 512, 4, 1)), ((8, 512, 2, 2), (8, 512, 2, 2)), ((8, 512, 2, 2), (8, 512, 2, 2)), ((8, 512, 2, 2), (8, 512, 2, 2)), ((8, 512, 1, 2), (8, 512, 1, 2)), ((8, 512, 1, 1), (8, 512, 1, 1)), ((8, 512, 1, 1), (8, 512, 1, 1)), ((8, 4096), (8, 4096)), ((8, 4096), (8, 4096))] ,
'matmul': [((8, 512), (512, 4096)), ((8, 4096), (4096, 4096)), ((8, 4096), (4096, 10))] ,
'fss_eq': [8] ,
},
"vgg16-cifar10-16": {
'conv2d': [((16, 3, 32, 32), (64, 3, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 64, 32, 32), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 64, 16, 16), (128, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 128, 16, 16), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 128, 8, 8), (256, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 256, 8, 8), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 256, 8, 8), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 256, 4, 4), (512, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 512, 2, 2), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 512, 2, 2), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 512, 2, 2), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1)))] ,
'fss_comp': [1048576, 524288, 262144, 262144, 524288, 262144, 131072, 131072, 262144, 262144, 131072, 65536, 65536, 131072, 131072, 65536, 32768, 32768, 32768, 32768, 16384, 8192, 8192, 65536, 65536, 1440, 160] ,
'mul': [((16, 64, 32, 32), (16, 64, 32, 32)), ((16, 64, 256, 2), (16, 64, 256, 2)), ((16, 64, 256, 1), (16, 64, 256, 1)), ((16, 64, 16, 16), (16, 64, 16, 16)), ((16, 128, 16, 16), (16, 128, 16, 16)), ((16, 128, 64, 2), (16, 128, 64, 2)), ((16, 128, 64, 1), (16, 128, 64, 1)), ((16, 128, 8, 8), (16, 128, 8, 8)), ((16, 256, 8, 8), (16, 256, 8, 8)), ((16, 256, 8, 8), (16, 256, 8, 8)), ((16, 256, 16, 2), (16, 256, 16, 2)), ((16, 256, 16, 1), (16, 256, 16, 1)), ((16, 256, 4, 4), (16, 256, 4, 4)), ((16, 512, 4, 4), (16, 512, 4, 4)), ((16, 512, 4, 4), (16, 512, 4, 4)), ((16, 512, 4, 2), (16, 512, 4, 2)), ((16, 512, 4, 1), (16, 512, 4, 1)), ((16, 512, 2, 2), (16, 512, 2, 2)), ((16, 512, 2, 2), (16, 512, 2, 2)), ((16, 512, 2, 2), (16, 512, 2, 2)), ((16, 512, 1, 2), (16, 512, 1, 2)), ((16, 512, 1, 1), (16, 512, 1, 1)), ((16, 512, 1, 1), (16, 512, 1, 1)), ((16, 4096), (16, 4096)), ((16, 4096), (16, 4096))] ,
'matmul': [((16, 512), (512, 4096)), ((16, 4096), (4096, 4096)), ((16, 4096), (4096, 10))] ,
'fss_eq': [16] ,
},
"vgg16-cifar10-64": {
'conv2d': [((64, 3, 32, 32), (64, 3, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((64, 64, 32, 32), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((64, 64, 16, 16), (128, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((64, 128, 16, 16), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((64, 128, 8, 8), (256, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((64, 256, 8, 8), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((64, 256, 8, 8), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((64, 256, 4, 4), (512, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((64, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((64, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((64, 512, 2, 2), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((64, 512, 2, 2), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((64, 512, 2, 2), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1)))] ,
'fss_comp': [4194304, 2097152, 1048576, 1048576, 2097152, 1048576, 524288, 524288, 1048576, 1048576, 524288, 262144, 262144, 524288, 524288, 262144, 131072, 131072, 131072, 131072, 65536, 32768, 32768, 262144, 262144, 5760, 640] ,
'mul': [((64, 64, 32, 32), (64, 64, 32, 32)), ((64, 64, 256, 2), (64, 64, 256, 2)), ((64, 64, 256, 1), (64, 64, 256, 1)), ((64, 64, 16, 16), (64, 64, 16, 16)), ((64, 128, 16, 16), (64, 128, 16, 16)), ((64, 128, 64, 2), (64, 128, 64, 2)), ((64, 128, 64, 1), (64, 128, 64, 1)), ((64, 128, 8, 8), (64, 128, 8, 8)), ((64, 256, 8, 8), (64, 256, 8, 8)), ((64, 256, 8, 8), (64, 256, 8, 8)), ((64, 256, 16, 2), (64, 256, 16, 2)), ((64, 256, 16, 1), (64, 256, 16, 1)), ((64, 256, 4, 4), (64, 256, 4, 4)), ((64, 512, 4, 4), (64, 512, 4, 4)), ((64, 512, 4, 4), (64, 512, 4, 4)), ((64, 512, 4, 2), (64, 512, 4, 2)), ((64, 512, 4, 1), (64, 512, 4, 1)), ((64, 512, 2, 2), (64, 512, 2, 2)), ((64, 512, 2, 2), (64, 512, 2, 2)), ((64, 512, 2, 2), (64, 512, 2, 2)), ((64, 512, 1, 2), (64, 512, 1, 2)), ((64, 512, 1, 1), (64, 512, 1, 1)), ((64, 512, 1, 1), (64, 512, 1, 1)), ((64, 4096), (64, 4096)), ((64, 4096), (64, 4096))] ,
'matmul': [((64, 512), (512, 4096)), ((64, 4096), (4096, 4096)), ((64, 4096), (4096, 10))] ,
'fss_eq': [64] ,
},
"vgg16-cifar10-128": {
'conv2d': [((128, 3, 32, 32), (64, 3, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((128, 64, 32, 32), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((128, 64, 16, 16), (128, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((128, 128, 16, 16), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((128, 128, 8, 8), (256, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((128, 256, 8, 8), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((128, 256, 8, 8), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((128, 256, 4, 4), (512, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((128, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((128, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((128, 512, 2, 2), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((128, 512, 2, 2), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((128, 512, 2, 2), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1)))] ,
'fss_comp': [8388608, 4194304, 2097152, 2097152, 4194304, 2097152, 1048576, 1048576, 2097152, 2097152, 1048576, 524288, 524288, 1048576, 1048576, 524288, 262144, 262144, 262144, 262144, 131072, 65536, 65536, 524288, 524288, 11520, 1280] ,
'mul': [((128, 64, 32, 32), (128, 64, 32, 32)), ((128, 64, 256, 2), (128, 64, 256, 2)), ((128, 64, 256, 1), (128, 64, 256, 1)), ((128, 64, 16, 16), (128, 64, 16, 16)), ((128, 128, 16, 16), (128, 128, 16, 16)), ((128, 128, 64, 2), (128, 128, 64, 2)), ((128, 128, 64, 1), (128, 128, 64, 1)), ((128, 128, 8, 8), (128, 128, 8, 8)), ((128, 256, 8, 8), (128, 256, 8, 8)), ((128, 256, 8, 8), (128, 256, 8, 8)), ((128, 256, 16, 2), (128, 256, 16, 2)), ((128, 256, 16, 1), (128, 256, 16, 1)), ((128, 256, 4, 4), (128, 256, 4, 4)), ((128, 512, 4, 4), (128, 512, 4, 4)), ((128, 512, 4, 4), (128, 512, 4, 4)), ((128, 512, 4, 2), (128, 512, 4, 2)), ((128, 512, 4, 1), (128, 512, 4, 1)), ((128, 512, 2, 2), (128, 512, 2, 2)), ((128, 512, 2, 2), (128, 512, 2, 2)), ((128, 512, 2, 2), (128, 512, 2, 2)), ((128, 512, 1, 2), (128, 512, 1, 2)), ((128, 512, 1, 1), (128, 512, 1, 1)), ((128, 512, 1, 1), (128, 512, 1, 1)), ((128, 4096), (128, 4096)), ((128, 4096), (128, 4096))] ,
'matmul': [((128, 512), (512, 4096)), ((128, 4096), (4096, 4096)), ((128, 4096), (4096, 10))] ,
'fss_eq': [128] ,
},
"vgg16-tiny-imagenet-1": {
'conv2d': [((1, 3, 64, 64), (64, 3, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 64, 64, 64), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 64, 32, 32), (128, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 128, 32, 32), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 128, 16, 16), (256, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 256, 16, 16), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 256, 16, 16), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 256, 8, 8), (512, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 512, 8, 8), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 512, 8, 8), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1)))] ,
'fss_comp': [262144, 131072, 65536, 65536, 131072, 65536, 32768, 32768, 65536, 65536, 32768, 16384, 16384, 32768, 32768, 16384, 8192, 8192, 8192, 8192, 4096, 2048, 2048, 4096, 4096, 39800, 200] ,
'mul': [((1, 64, 64, 64), (1, 64, 64, 64)), ((1, 64, 1024, 2), (1, 64, 1024, 2)), ((1, 64, 1024, 1), (1, 64, 1024, 1)), ((1, 64, 32, 32), (1, 64, 32, 32)), ((1, 128, 32, 32), (1, 128, 32, 32)), ((1, 128, 256, 2), (1, 128, 256, 2)), ((1, 128, 256, 1), (1, 128, 256, 1)), ((1, 128, 16, 16), (1, 128, 16, 16)), ((1, 256, 16, 16), (1, 256, 16, 16)), ((1, 256, 16, 16), (1, 256, 16, 16)), ((1, 256, 64, 2), (1, 256, 64, 2)), ((1, 256, 64, 1), (1, 256, 64, 1)), ((1, 256, 8, 8), (1, 256, 8, 8)), ((1, 512, 8, 8), (1, 512, 8, 8)), ((1, 512, 8, 8), (1, 512, 8, 8)), ((1, 512, 16, 2), (1, 512, 16, 2)), ((1, 512, 16, 1), (1, 512, 16, 1)), ((1, 512, 4, 4), (1, 512, 4, 4)), ((1, 512, 4, 4), (1, 512, 4, 4)), ((1, 512, 4, 4), (1, 512, 4, 4)), ((1, 512, 4, 2), (1, 512, 4, 2)), ((1, 512, 4, 1), (1, 512, 4, 1)), ((1, 512, 2, 2), (1, 512, 2, 2)), ((1, 4096), (1, 4096)), ((1, 4096), (1, 4096))] ,
'matmul': [((1, 2048), (2048, 4096)), ((1, 4096), (4096, 4096)), ((1, 4096), (4096, 200))] ,
'fss_eq': [1] ,
},
"vgg16-tiny-imagenet-16": {
'conv2d': [((16, 3, 64, 64), (64, 3, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 64, 64, 64), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 64, 32, 32), (128, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 128, 32, 32), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 128, 16, 16), (256, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 256, 16, 16), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 256, 16, 16), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 256, 8, 8), (512, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 512, 8, 8), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 512, 8, 8), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1)))] ,
'fss_comp': [4194304, 2097152, 1048576, 1048576, 2097152, 1048576, 524288, 524288, 1048576, 1048576, 524288, 262144, 262144, 524288, 524288, 262144, 131072, 131072, 131072, 131072, 65536, 32768, 32768, 65536, 65536, 636800, 3200] ,
'mul': [((16, 64, 64, 64), (16, 64, 64, 64)), ((16, 64, 1024, 2), (16, 64, 1024, 2)), ((16, 64, 1024, 1), (16, 64, 1024, 1)), ((16, 64, 32, 32), (16, 64, 32, 32)), ((16, 128, 32, 32), (16, 128, 32, 32)), ((16, 128, 256, 2), (16, 128, 256, 2)), ((16, 128, 256, 1), (16, 128, 256, 1)), ((16, 128, 16, 16), (16, 128, 16, 16)), ((16, 256, 16, 16), (16, 256, 16, 16)), ((16, 256, 16, 16), (16, 256, 16, 16)), ((16, 256, 64, 2), (16, 256, 64, 2)), ((16, 256, 64, 1), (16, 256, 64, 1)), ((16, 256, 8, 8), (16, 256, 8, 8)), ((16, 512, 8, 8), (16, 512, 8, 8)), ((16, 512, 8, 8), (16, 512, 8, 8)), ((16, 512, 16, 2), (16, 512, 16, 2)), ((16, 512, 16, 1), (16, 512, 16, 1)), ((16, 512, 4, 4), (16, 512, 4, 4)), ((16, 512, 4, 4), (16, 512, 4, 4)), ((16, 512, 4, 4), (16, 512, 4, 4)), ((16, 512, 4, 2), (16, 512, 4, 2)), ((16, 512, 4, 1), (16, 512, 4, 1)), ((16, 512, 2, 2), (16, 512, 2, 2)), ((16, 4096), (16, 4096)), ((16, 4096), (16, 4096))] ,
'matmul': [((16, 2048), (2048, 4096)), ((16, 4096), (4096, 4096)), ((16, 4096), (4096, 200))] ,
'fss_eq': [16] ,
},
"vgg16-tiny-imagenet-32": {
'conv2d': [((32, 3, 64, 64), (64, 3, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((32, 64, 64, 64), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((32, 64, 32, 32), (128, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((32, 128, 32, 32), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((32, 128, 16, 16), (256, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((32, 256, 16, 16), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((32, 256, 16, 16), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((32, 256, 8, 8), (512, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((32, 512, 8, 8), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((32, 512, 8, 8), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((32, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((32, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((32, 512, 4, 4), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1)))] ,
'fss_comp': [8388608, 4194304, 2097152, 2097152, 4194304, 2097152, 1048576, 1048576, 2097152, 2097152, 1048576, 524288, 524288, 1048576, 1048576, 524288, 262144, 262144, 262144, 262144, 131072, 65536, 65536, 131072, 131072, 1273600, 6400] ,
'mul': [((32, 64, 64, 64), (32, 64, 64, 64)), ((32, 64, 1024, 2), (32, 64, 1024, 2)), ((32, 64, 1024, 1), (32, 64, 1024, 1)), ((32, 64, 32, 32), (32, 64, 32, 32)), ((32, 128, 32, 32), (32, 128, 32, 32)), ((32, 128, 256, 2), (32, 128, 256, 2)), ((32, 128, 256, 1), (32, 128, 256, 1)), ((32, 128, 16, 16), (32, 128, 16, 16)), ((32, 256, 16, 16), (32, 256, 16, 16)), ((32, 256, 16, 16), (32, 256, 16, 16)), ((32, 256, 64, 2), (32, 256, 64, 2)), ((32, 256, 64, 1), (32, 256, 64, 1)), ((32, 256, 8, 8), (32, 256, 8, 8)), ((32, 512, 8, 8), (32, 512, 8, 8)), ((32, 512, 8, 8), (32, 512, 8, 8)), ((32, 512, 16, 2), (32, 512, 16, 2)), ((32, 512, 16, 1), (32, 512, 16, 1)), ((32, 512, 4, 4), (32, 512, 4, 4)), ((32, 512, 4, 4), (32, 512, 4, 4)), ((32, 512, 4, 4), (32, 512, 4, 4)), ((32, 512, 4, 2), (32, 512, 4, 2)), ((32, 512, 4, 1), (32, 512, 4, 1)), ((32, 512, 2, 2), (32, 512, 2, 2)), ((32, 4096), (32, 4096)), ((32, 4096), (32, 4096))] ,
'matmul': [((32, 2048), (2048, 4096)), ((32, 4096), (4096, 4096)), ((32, 4096), (4096, 200))] ,
'fss_eq': [32] ,
},
"resnet18-hymenoptera-1": {
'conv2d': [((1, 3, 224, 224), (64, 3, 7, 7), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (3, 3), (1, 1), 1))), ((1, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 64, 56, 56), (128, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (1, 1), (1, 1), 1))), ((1, 128, 28, 28), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 64, 56, 56), (128, 64, 1, 1), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (0, 0), (1, 1), 1))), ((1, 128, 28, 28), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 128, 28, 28), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 128, 28, 28), (256, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (1, 1), (1, 1), 1))), ((1, 256, 14, 14), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 128, 28, 28), (256, 128, 1, 1), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (0, 0), (1, 1), 1))), ((1, 256, 14, 14), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 256, 14, 14), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 256, 14, 14), (512, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (1, 1), (1, 1), 1))), ((1, 512, 7, 7), (512, 512, 3, 3), 
(('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 256, 14, 14), (512, 256, 1, 1), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (0, 0), (1, 1), 1))), ((1, 512, 7, 7), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((1, 512, 7, 7), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1)))] ,
'mul': [((64,), (12544, 64)), ((12544, 64), (64,)), ((1, 64, 3136, 4), (1, 64, 3136, 4)), ((1, 64, 3136, 2), (1, 64, 3136, 2)), ((1, 64, 3136, 1), (1, 64, 3136, 1)), ((1, 64, 3136, 1), (1, 64, 3136, 1)), ((1, 64, 56, 56), (1, 64, 56, 56)), ((64,), (3136, 64)), ((3136, 64), (64,)), ((1, 64, 56, 56), (1, 64, 56, 56)), ((64,), (3136, 64)), ((3136, 64), (64,)), ((1, 64, 56, 56), (1, 64, 56, 56)), ((64,), (3136, 64)), ((3136, 64), (64,)), ((1, 64, 56, 56), (1, 64, 56, 56)), ((64,), (3136, 64)), ((3136, 64), (64,)), ((1, 64, 56, 56), (1, 64, 56, 56)), ((128,), (784, 128)), ((784, 128), (128,)), ((1, 128, 28, 28), (1, 128, 28, 28)), ((128,), (784, 128)), ((784, 128), (128,)), ((128,), (784, 128)), ((784, 128), (128,)), ((1, 128, 28, 28), (1, 128, 28, 28)), ((128,), (784, 128)), ((784, 128), (128,)), ((1, 128, 28, 28), (1, 128, 28, 28)), ((128,), (784, 128)), ((784, 128), (128,)), ((1, 128, 28, 28), (1, 128, 28, 28)), ((256,), (196, 256)), ((196, 256), (256,)), ((1, 256, 14, 14), (1, 256, 14, 14)), ((256,), (196, 256)), ((196, 256), (256,)), ((256,), (196, 256)), ((196, 256), (256,)), ((1, 256, 14, 14), (1, 256, 14, 14)), ((256,), (196, 256)), ((196, 256), (256,)), ((1, 256, 14, 14), (1, 256, 14, 14)), ((256,), (196, 256)), ((196, 256), (256,)), ((1, 256, 14, 14), (1, 256, 14, 14)), ((512,), (49, 512)), ((49, 512), (512,)), ((1, 512, 7, 7), (1, 512, 7, 7)), ((512,), (49, 512)), ((49, 512), (512,)), ((512,), (49, 512)), ((49, 512), (512,)), ((1, 512, 7, 7), (1, 512, 7, 7)), ((512,), (49, 512)), ((49, 512), (512,)), ((1, 512, 7, 7), (1, 512, 7, 7)), ((512,), (49, 512)), ((49, 512), (512,)), ((1, 512, 7, 7), (1, 512, 7, 7))] ,
'fss_comp': [802816, 401408, 200704, 200704, 200704, 200704, 200704, 200704, 200704, 100352, 100352, 100352, 100352, 50176, 50176, 50176, 50176, 25088, 25088, 25088, 25088, 2, 2] ,
'matmul': [((1, 512), (512, 2))] ,
'fss_eq': [1] ,
},
"resnet18-hymenoptera-2": {
'conv2d': [((2, 3, 224, 224), (64, 3, 7, 7), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (3, 3), (1, 1), 1))), ((2, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((2, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((2, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((2, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((2, 64, 56, 56), (128, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (1, 1), (1, 1), 1))), ((2, 128, 28, 28), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((2, 64, 56, 56), (128, 64, 1, 1), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (0, 0), (1, 1), 1))), ((2, 128, 28, 28), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((2, 128, 28, 28), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((2, 128, 28, 28), (256, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (1, 1), (1, 1), 1))), ((2, 256, 14, 14), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((2, 128, 28, 28), (256, 128, 1, 1), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (0, 0), (1, 1), 1))), ((2, 256, 14, 14), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((2, 256, 14, 14), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((2, 256, 14, 14), (512, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (1, 1), (1, 1), 1))), ((2, 512, 7, 7), (512, 512, 3, 3), 
(('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((2, 256, 14, 14), (512, 256, 1, 1), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (0, 0), (1, 1), 1))), ((2, 512, 7, 7), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((2, 512, 7, 7), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1)))] ,
'mul': [((64,), (25088, 64)), ((25088, 64), (64,)), ((2, 64, 3136, 4), (2, 64, 3136, 4)), ((2, 64, 3136, 2), (2, 64, 3136, 2)), ((2, 64, 3136, 1), (2, 64, 3136, 1)), ((2, 64, 3136, 1), (2, 64, 3136, 1)), ((2, 64, 56, 56), (2, 64, 56, 56)), ((64,), (6272, 64)), ((6272, 64), (64,)), ((2, 64, 56, 56), (2, 64, 56, 56)), ((64,), (6272, 64)), ((6272, 64), (64,)), ((2, 64, 56, 56), (2, 64, 56, 56)), ((64,), (6272, 64)), ((6272, 64), (64,)), ((2, 64, 56, 56), (2, 64, 56, 56)), ((64,), (6272, 64)), ((6272, 64), (64,)), ((2, 64, 56, 56), (2, 64, 56, 56)), ((128,), (1568, 128)), ((1568, 128), (128,)), ((2, 128, 28, 28), (2, 128, 28, 28)), ((128,), (1568, 128)), ((1568, 128), (128,)), ((128,), (1568, 128)), ((1568, 128), (128,)), ((2, 128, 28, 28), (2, 128, 28, 28)), ((128,), (1568, 128)), ((1568, 128), (128,)), ((2, 128, 28, 28), (2, 128, 28, 28)), ((128,), (1568, 128)), ((1568, 128), (128,)), ((2, 128, 28, 28), (2, 128, 28, 28)), ((256,), (392, 256)), ((392, 256), (256,)), ((2, 256, 14, 14), (2, 256, 14, 14)), ((256,), (392, 256)), ((392, 256), (256,)), ((256,), (392, 256)), ((392, 256), (256,)), ((2, 256, 14, 14), (2, 256, 14, 14)), ((256,), (392, 256)), ((392, 256), (256,)), ((2, 256, 14, 14), (2, 256, 14, 14)), ((256,), (392, 256)), ((392, 256), (256,)), ((2, 256, 14, 14), (2, 256, 14, 14)), ((512,), (98, 512)), ((98, 512), (512,)), ((2, 512, 7, 7), (2, 512, 7, 7)), ((512,), (98, 512)), ((98, 512), (512,)), ((512,), (98, 512)), ((98, 512), (512,)), ((2, 512, 7, 7), (2, 512, 7, 7)), ((512,), (98, 512)), ((98, 512), (512,)), ((2, 512, 7, 7), (2, 512, 7, 7)), ((512,), (98, 512)), ((98, 512), (512,)), ((2, 512, 7, 7), (2, 512, 7, 7))] ,
'fss_comp': [1605632, 802816, 401408, 401408, 401408, 401408, 401408, 401408, 401408, 200704, 200704, 200704, 200704, 100352, 100352, 100352, 100352, 50176, 50176, 50176, 50176, 4, 4] ,
'matmul': [((2, 512), (512, 2))] ,
'fss_eq': [2] ,
},
"resnet18-hymenoptera-4": {
'conv2d': [((4, 3, 224, 224), (64, 3, 7, 7), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (3, 3), (1, 1), 1))), ((4, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((4, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((4, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((4, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((4, 64, 56, 56), (128, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (1, 1), (1, 1), 1))), ((4, 128, 28, 28), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((4, 64, 56, 56), (128, 64, 1, 1), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (0, 0), (1, 1), 1))), ((4, 128, 28, 28), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((4, 128, 28, 28), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((4, 128, 28, 28), (256, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (1, 1), (1, 1), 1))), ((4, 256, 14, 14), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((4, 128, 28, 28), (256, 128, 1, 1), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (0, 0), (1, 1), 1))), ((4, 256, 14, 14), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((4, 256, 14, 14), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((4, 256, 14, 14), (512, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (1, 1), (1, 1), 1))), ((4, 512, 7, 7), (512, 512, 3, 3), 
(('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((4, 256, 14, 14), (512, 256, 1, 1), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (0, 0), (1, 1), 1))), ((4, 512, 7, 7), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((4, 512, 7, 7), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1)))] ,
'mul': [((64,), (50176, 64)), ((50176, 64), (64,)), ((4, 64, 3136, 4), (4, 64, 3136, 4)), ((4, 64, 3136, 2), (4, 64, 3136, 2)), ((4, 64, 3136, 1), (4, 64, 3136, 1)), ((4, 64, 3136, 1), (4, 64, 3136, 1)), ((4, 64, 56, 56), (4, 64, 56, 56)), ((64,), (12544, 64)), ((12544, 64), (64,)), ((4, 64, 56, 56), (4, 64, 56, 56)), ((64,), (12544, 64)), ((12544, 64), (64,)), ((4, 64, 56, 56), (4, 64, 56, 56)), ((64,), (12544, 64)), ((12544, 64), (64,)), ((4, 64, 56, 56), (4, 64, 56, 56)), ((64,), (12544, 64)), ((12544, 64), (64,)), ((4, 64, 56, 56), (4, 64, 56, 56)), ((128,), (3136, 128)), ((3136, 128), (128,)), ((4, 128, 28, 28), (4, 128, 28, 28)), ((128,), (3136, 128)), ((3136, 128), (128,)), ((128,), (3136, 128)), ((3136, 128), (128,)), ((4, 128, 28, 28), (4, 128, 28, 28)), ((128,), (3136, 128)), ((3136, 128), (128,)), ((4, 128, 28, 28), (4, 128, 28, 28)), ((128,), (3136, 128)), ((3136, 128), (128,)), ((4, 128, 28, 28), (4, 128, 28, 28)), ((256,), (784, 256)), ((784, 256), (256,)), ((4, 256, 14, 14), (4, 256, 14, 14)), ((256,), (784, 256)), ((784, 256), (256,)), ((256,), (784, 256)), ((784, 256), (256,)), ((4, 256, 14, 14), (4, 256, 14, 14)), ((256,), (784, 256)), ((784, 256), (256,)), ((4, 256, 14, 14), (4, 256, 14, 14)), ((256,), (784, 256)), ((784, 256), (256,)), ((4, 256, 14, 14), (4, 256, 14, 14)), ((512,), (196, 512)), ((196, 512), (512,)), ((4, 512, 7, 7), (4, 512, 7, 7)), ((512,), (196, 512)), ((196, 512), (512,)), ((512,), (196, 512)), ((196, 512), (512,)), ((4, 512, 7, 7), (4, 512, 7, 7)), ((512,), (196, 512)), ((196, 512), (512,)), ((4, 512, 7, 7), (4, 512, 7, 7)), ((512,), (196, 512)), ((196, 512), (512,)), ((4, 512, 7, 7), (4, 512, 7, 7))] ,
'fss_comp': [3211264, 1605632, 802816, 802816, 802816, 802816, 802816, 802816, 802816, 401408, 401408, 401408, 401408, 200704, 200704, 200704, 200704, 100352, 100352, 100352, 100352, 8, 8] ,
'matmul': [((4, 512), (512, 2))] ,
'fss_eq': [4] ,
},
"resnet18-hymenoptera-8": {
'conv2d': [((8, 3, 224, 224), (64, 3, 7, 7), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (3, 3), (1, 1), 1))), ((8, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 64, 56, 56), (128, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (1, 1), (1, 1), 1))), ((8, 128, 28, 28), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 64, 56, 56), (128, 64, 1, 1), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (0, 0), (1, 1), 1))), ((8, 128, 28, 28), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 128, 28, 28), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 128, 28, 28), (256, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (1, 1), (1, 1), 1))), ((8, 256, 14, 14), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 128, 28, 28), (256, 128, 1, 1), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (0, 0), (1, 1), 1))), ((8, 256, 14, 14), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 256, 14, 14), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 256, 14, 14), (512, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (1, 1), (1, 1), 1))), ((8, 512, 7, 7), (512, 512, 3, 3), 
(('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 256, 14, 14), (512, 256, 1, 1), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (0, 0), (1, 1), 1))), ((8, 512, 7, 7), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((8, 512, 7, 7), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1)))] ,
'mul': [((64,), (100352, 64)), ((100352, 64), (64,)), ((8, 64, 3136, 4), (8, 64, 3136, 4)), ((8, 64, 3136, 2), (8, 64, 3136, 2)), ((8, 64, 3136, 1), (8, 64, 3136, 1)), ((8, 64, 3136, 1), (8, 64, 3136, 1)), ((8, 64, 56, 56), (8, 64, 56, 56)), ((64,), (25088, 64)), ((25088, 64), (64,)), ((8, 64, 56, 56), (8, 64, 56, 56)), ((64,), (25088, 64)), ((25088, 64), (64,)), ((8, 64, 56, 56), (8, 64, 56, 56)), ((64,), (25088, 64)), ((25088, 64), (64,)), ((8, 64, 56, 56), (8, 64, 56, 56)), ((64,), (25088, 64)), ((25088, 64), (64,)), ((8, 64, 56, 56), (8, 64, 56, 56)), ((128,), (6272, 128)), ((6272, 128), (128,)), ((8, 128, 28, 28), (8, 128, 28, 28)), ((128,), (6272, 128)), ((6272, 128), (128,)), ((128,), (6272, 128)), ((6272, 128), (128,)), ((8, 128, 28, 28), (8, 128, 28, 28)), ((128,), (6272, 128)), ((6272, 128), (128,)), ((8, 128, 28, 28), (8, 128, 28, 28)), ((128,), (6272, 128)), ((6272, 128), (128,)), ((8, 128, 28, 28), (8, 128, 28, 28)), ((256,), (1568, 256)), ((1568, 256), (256,)), ((8, 256, 14, 14), (8, 256, 14, 14)), ((256,), (1568, 256)), ((1568, 256), (256,)), ((256,), (1568, 256)), ((1568, 256), (256,)), ((8, 256, 14, 14), (8, 256, 14, 14)), ((256,), (1568, 256)), ((1568, 256), (256,)), ((8, 256, 14, 14), (8, 256, 14, 14)), ((256,), (1568, 256)), ((1568, 256), (256,)), ((8, 256, 14, 14), (8, 256, 14, 14)), ((512,), (392, 512)), ((392, 512), (512,)), ((8, 512, 7, 7), (8, 512, 7, 7)), ((512,), (392, 512)), ((392, 512), (512,)), ((512,), (392, 512)), ((392, 512), (512,)), ((8, 512, 7, 7), (8, 512, 7, 7)), ((512,), (392, 512)), ((392, 512), (512,)), ((8, 512, 7, 7), (8, 512, 7, 7)), ((512,), (392, 512)), ((392, 512), (512,)), ((8, 512, 7, 7), (8, 512, 7, 7))] ,
'fss_comp': [6422528, 3211264, 1605632, 1605632, 1605632, 1605632, 1605632, 1605632, 1605632, 802816, 802816, 802816, 802816, 401408, 401408, 401408, 401408, 200704, 200704, 200704, 200704, 16, 16] ,
'matmul': [((8, 512), (512, 2))] ,
'fss_eq': [8] ,
},
"resnet18-hymenoptera-16": {
'conv2d': [((16, 3, 224, 224), (64, 3, 7, 7), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (3, 3), (1, 1), 1))), ((16, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 64, 56, 56), (64, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 64, 56, 56), (128, 64, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (1, 1), (1, 1), 1))), ((16, 128, 28, 28), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 64, 56, 56), (128, 64, 1, 1), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (0, 0), (1, 1), 1))), ((16, 128, 28, 28), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 128, 28, 28), (128, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 128, 28, 28), (256, 128, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (1, 1), (1, 1), 1))), ((16, 256, 14, 14), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 128, 28, 28), (256, 128, 1, 1), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (0, 0), (1, 1), 1))), ((16, 256, 14, 14), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 256, 14, 14), (256, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 256, 14, 14), (512, 256, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (1, 1), (1, 1), 1))), ((16, 512, 7, 7), (512, 
512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 256, 14, 14), (512, 256, 1, 1), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 2, (0, 0), (1, 1), 1))), ((16, 512, 7, 7), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1))), ((16, 512, 7, 7), (512, 512, 3, 3), (('bias', 'stride', 'padding', 'dilation', 'groups'), (None, 1, (1, 1), (1, 1), 1)))] ,
'mul': [((64,), (200704, 64)), ((200704, 64), (64,)), ((16, 64, 3136, 4), (16, 64, 3136, 4)), ((16, 64, 3136, 2), (16, 64, 3136, 2)), ((16, 64, 3136, 1), (16, 64, 3136, 1)), ((16, 64, 3136, 1), (16, 64, 3136, 1)), ((16, 64, 56, 56), (16, 64, 56, 56)), ((64,), (50176, 64)), ((50176, 64), (64,)), ((16, 64, 56, 56), (16, 64, 56, 56)), ((64,), (50176, 64)), ((50176, 64), (64,)), ((16, 64, 56, 56), (16, 64, 56, 56)), ((64,), (50176, 64)), ((50176, 64), (64,)), ((16, 64, 56, 56), (16, 64, 56, 56)), ((64,), (50176, 64)), ((50176, 64), (64,)), ((16, 64, 56, 56), (16, 64, 56, 56)), ((128,), (12544, 128)), ((12544, 128), (128,)), ((16, 128, 28, 28), (16, 128, 28, 28)), ((128,), (12544, 128)), ((12544, 128), (128,)), ((128,), (12544, 128)), ((12544, 128), (128,)), ((16, 128, 28, 28), (16, 128, 28, 28)), ((128,), (12544, 128)), ((12544, 128), (128,)), ((16, 128, 28, 28), (16, 128, 28, 28)), ((128,), (12544, 128)), ((12544, 128), (128,)), ((16, 128, 28, 28), (16, 128, 28, 28)), ((256,), (3136, 256)), ((3136, 256), (256,)), ((16, 256, 14, 14), (16, 256, 14, 14)), ((256,), (3136, 256)), ((3136, 256), (256,)), ((256,), (3136, 256)), ((3136, 256), (256,)), ((16, 256, 14, 14), (16, 256, 14, 14)), ((256,), (3136, 256)), ((3136, 256), (256,)), ((16, 256, 14, 14), (16, 256, 14, 14)), ((256,), (3136, 256)), ((3136, 256), (256,)), ((16, 256, 14, 14), (16, 256, 14, 14)), ((512,), (784, 512)), ((784, 512), (512,)), ((16, 512, 7, 7), (16, 512, 7, 7)), ((512,), (784, 512)), ((784, 512), (512,)), ((512,), (784, 512)), ((784, 512), (512,)), ((16, 512, 7, 7), (16, 512, 7, 7)), ((512,), (784, 512)), ((784, 512), (512,)), ((16, 512, 7, 7), (16, 512, 7, 7)), ((512,), (784, 512)), ((784, 512), (512,)), ((16, 512, 7, 7), (16, 512, 7, 7))] ,
'fss_comp': [12845056, 6422528, 3211264, 3211264, 3211264, 3211264, 3211264, 3211264, 3211264, 1605632, 1605632, 1605632, 1605632, 802816, 802816, 802816, 802816, 401408, 401408, 401408, 401408, 32, 32] ,
'matmul': [((16, 512), (512, 2))] ,
'fss_eq': [16] ,
}
}
# fmt: on
def build_prepocessing(model, dataset, batch_size, workers, args):
    """Generate the crypto primitives recorded for a model/dataset/batch size.

    Looks up the ``"{model}-{dataset}-{batch_size}"`` entry of ``config_zoo``.
    It then asks ``sy.local_worker.crypto_store`` to provide primitives for
    every operation recorded there:

    * ``fss_eq`` / ``fss_comp``: one call per recorded instance count.
    * ``conv2d``: one call per recorded (left shape, right shape, kwargs)
      triple.
    * ``mul`` / ``matmul``: one call with the whole recorded shape list.

    Args:
        model: Model name, first component of the ``config_zoo`` key.
        dataset: Dataset name, second component of the ``config_zoo`` key.
        batch_size: Batch size, last component of the ``config_zoo`` key. It
            is also the divisor for the per-item time/communication reports.
        workers: Forwarded unchanged to ``provide_primitives``.
        args: Parsed options. Reads ``comm_info``, ``verbose`` and ``dtype``
            (``"int"`` or ``"long"``).

    Returns:
        Wall-clock preprocessing time in seconds. Returns ``0`` (after printing
        a warning) when ``config_zoo`` has no entry for this combination.

    Raises:
        ValueError: If ``args.dtype`` is neither ``"int"`` nor ``"long"``.
    """
    start_time = time.time()
    try:
        config = config_zoo[f"{model}-{dataset}-{batch_size}"]
    except KeyError:
        print(f"WARNING: No preprocessing found for {model}-{dataset}-{batch_size}")
        return 0

    # The ring is loop-invariant: resolve it once, and reject an unsupported
    # dtype before spending time on any primitive generation.
    if args.dtype == "int":
        torch_dtype, field = th.int32, 2 ** 32
    elif args.dtype == "long":
        torch_dtype, field = th.int64, 2 ** 64
    else:
        raise ValueError(f"Unsupported dtype {args.dtype}")

    crypto_store = sy.local_worker.crypto_store

    if args.comm_info:
        # Counter read back after provisioning and removed again in `finally`.
        # Presumably syft increments it while primitives are exchanged.
        sy.comm_total = 0
    try:
        if args.verbose:
            print("Preprocess")

        # Function-secret-sharing keys: one call per recorded instance count.
        for op in ("fss_eq", "fss_comp"):
            for n_instances in config.get(op, []):
                if args.verbose:
                    print(f"{op} n_instances", n_instances)
                crypto_store.provide_primitives(
                    op=op, kwargs_={}, workers=workers, n_instances=n_instances
                )

        # Shape-dependent primitives. A tuple (not a set) keeps the
        # provisioning order deterministic across interpreter runs.
        for op in ("mul", "matmul", "conv2d"):
            if op not in config:
                continue
            shapes = config[op]
            if args.verbose:
                print(f"{op} shapes", shapes)
            if op == "conv2d":
                # conv2d entries carry their kwargs as a hashable
                # (keys, values) pair; rebuild the dict for each call.
                for left_shape, right_shape, (keys, values) in shapes:
                    crypto_store.provide_primitives(
                        op=op,
                        kwargs_=dict(zip(keys, values)),
                        workers=workers,
                        n_instances=1,
                        shapes=(left_shape, right_shape),
                        dtype=args.dtype,
                        torch_dtype=torch_dtype,
                        field=field,
                    )
            else:
                crypto_store.provide_primitives(
                    op=op,
                    kwargs_={},
                    workers=workers,
                    n_instances=1,
                    shapes=shapes,
                    dtype=args.dtype,
                    torch_dtype=torch_dtype,
                    field=field,
                )

        preprocess_time = time.time() - start_time
        if args.comm_info:
            print(
                "Total communication per item",
                round(sy.comm_total / batch_size / 10 ** 6, 3),
                "MB",
            )
    finally:
        # Always drop the counter so a failed run does not leave it active.
        if args.comm_info and hasattr(sy, "comm_total"):
            del sy.comm_total

    if args.verbose:
        print(
            "...",
            preprocess_time,
            "s",
            "[time per item=",
            preprocess_time / batch_size,
            "]",
        )
    else:
        print("Preprocessing time (s):\t", round(preprocess_time / batch_size, 4))
    return preprocess_time