10
10
IRVariable ,
11
11
flip_comparison_opcode ,
12
12
)
13
- from vyper .venom .passes .base_pass import IRPass
13
+ from vyper .venom .passes .base_pass import InstUpdater , IRPass
14
14
15
15
TRUTHY_INSTRUCTIONS = ("iszero" , "jnz" , "assert" , "assert_unreachable" )
16
16
@@ -19,62 +19,6 @@ def lit_eq(op: IROperand, val: int) -> bool:
19
19
return isinstance (op , IRLiteral ) and wrap256 (op .value ) == wrap256 (val )
20
20
21
21
22
- class InstructionUpdater :
23
- """
24
- A helper class for updating instructions which also updates the
25
- basic block and dfg in place
26
- """
27
-
28
- def __init__ (self , dfg : DFGAnalysis ):
29
- self .dfg = dfg
30
-
31
- def _update_operands (self , inst : IRInstruction , replace_dict : dict [IROperand , IROperand ]):
32
- old_operands = inst .operands
33
- new_operands = [replace_dict [op ] if op in replace_dict else op for op in old_operands ]
34
- self ._update (inst , inst .opcode , new_operands )
35
-
36
- def _update (self , inst : IRInstruction , opcode : str , new_operands : list [IROperand ]):
37
- assert opcode != "phi"
38
- # sanity
39
- assert all (isinstance (op , IROperand ) for op in new_operands )
40
-
41
- old_operands = inst .operands
42
-
43
- for op in old_operands :
44
- if not isinstance (op , IRVariable ):
45
- continue
46
- uses = self .dfg .get_uses (op )
47
- if inst in uses :
48
- uses .remove (inst )
49
-
50
- for op in new_operands :
51
- if isinstance (op , IRVariable ):
52
- self .dfg .add_use (op , inst )
53
-
54
- inst .opcode = opcode
55
- inst .operands = new_operands
56
-
57
- def _store (self , inst : IRInstruction , op : IROperand ):
58
- self ._update (inst , "store" , [op ])
59
-
60
- def _add_before (self , inst : IRInstruction , opcode : str , args : list [IROperand ]) -> IRVariable :
61
- """
62
- Insert another instruction before the given instruction
63
- """
64
- assert opcode != "phi"
65
- index = inst .parent .instructions .index (inst )
66
- var = inst .parent .parent .get_next_variable ()
67
- operands = list (args )
68
- new_inst = IRInstruction (opcode , operands , output = var )
69
- inst .parent .insert_instruction (new_inst , index )
70
- for op in new_inst .operands :
71
- if isinstance (op , IRVariable ):
72
- self .dfg .add_use (op , new_inst )
73
- self .dfg .add_use (var , inst )
74
- self .dfg .set_producing_instruction (var , new_inst )
75
- return var
76
-
77
-
78
22
class AlgebraicOptimizationPass (IRPass ):
79
23
"""
80
24
This pass reduces algebraic evaluatable expressions.
@@ -86,18 +30,17 @@ class AlgebraicOptimizationPass(IRPass):
86
30
"""
87
31
88
32
dfg : DFGAnalysis
89
- updater : InstructionUpdater
33
+ updater : InstUpdater
90
34
91
35
def run_pass (self ):
92
36
self .dfg = self .analyses_cache .request_analysis (DFGAnalysis ) # type: ignore
93
- self .updater = InstructionUpdater (self .dfg )
37
+ self .updater = InstUpdater (self .dfg )
94
38
self ._handle_offset ()
95
39
96
40
self ._algebraic_opt ()
97
41
self ._optimize_iszero_chains ()
98
42
self ._algebraic_opt ()
99
43
100
- self .analyses_cache .invalidate_analysis (DFGAnalysis )
101
44
self .analyses_cache .invalidate_analysis (LivenessAnalysis )
102
45
103
46
def _optimize_iszero_chains (self ) -> None :
@@ -132,7 +75,7 @@ def _optimize_iszero_chains(self) -> None:
132
75
continue
133
76
134
77
out_var = iszero_chain [keep_count ].operands [0 ]
135
- self .updater ._update_operands (use_inst , {inst .output : out_var })
78
+ self .updater .update_operands (use_inst , {inst .output : out_var })
136
79
137
80
def _get_iszero_chain (self , op : IROperand ) -> list [IRInstruction ]:
138
81
chain : list [IRInstruction ] = []
@@ -207,30 +150,30 @@ def _handle_inst_peephole(self, inst: IRInstruction):
207
150
if inst .opcode in {"shl" , "shr" , "sar" }:
208
151
# (x >> 0) == (x << 0) == x
209
152
if lit_eq (operands [1 ], 0 ):
210
- self .updater ._store (inst , operands [0 ])
153
+ self .updater .store (inst , operands [0 ])
211
154
return
212
155
# no more cases for these instructions
213
156
return
214
157
215
158
if inst .opcode == "exp" :
216
159
# x ** 0 -> 1
217
160
if lit_eq (operands [0 ], 0 ):
218
- self .updater ._store (inst , IRLiteral (1 ))
161
+ self .updater .store (inst , IRLiteral (1 ))
219
162
return
220
163
221
164
# 1 ** x -> 1
222
165
if lit_eq (operands [1 ], 1 ):
223
- self .updater ._store (inst , IRLiteral (1 ))
166
+ self .updater .store (inst , IRLiteral (1 ))
224
167
return
225
168
226
169
# 0 ** x -> iszero x
227
170
if lit_eq (operands [1 ], 0 ):
228
- self .updater ._update (inst , "iszero" , [operands [0 ]])
171
+ self .updater .update (inst , "iszero" , [operands [0 ]])
229
172
return
230
173
231
174
# x ** 1 -> x
232
175
if lit_eq (operands [0 ], 1 ):
233
- self .updater ._store (inst , operands [1 ])
176
+ self .updater .store (inst , operands [1 ])
234
177
return
235
178
236
179
# no more cases for this instruction
@@ -239,64 +182,64 @@ def _handle_inst_peephole(self, inst: IRInstruction):
239
182
if inst .opcode in {"add" , "sub" , "xor" }:
240
183
# (x - x) == (x ^ x) == 0
241
184
if inst .opcode in ("xor" , "sub" ) and operands [0 ] == operands [1 ]:
242
- self .updater ._store (inst , IRLiteral (0 ))
185
+ self .updater .store (inst , IRLiteral (0 ))
243
186
return
244
187
245
188
# (x + 0) == (0 + x) -> x
246
189
# x - 0 -> x
247
190
# (x ^ 0) == (0 ^ x) -> x
248
191
if lit_eq (operands [0 ], 0 ):
249
- self .updater ._store (inst , operands [1 ])
192
+ self .updater .store (inst , operands [1 ])
250
193
return
251
194
252
195
# (-1) - x -> ~x
253
196
# from two's complement
254
197
if inst .opcode == "sub" and lit_eq (operands [1 ], - 1 ):
255
- self .updater ._update (inst , "not" , [operands [0 ]])
198
+ self .updater .update (inst , "not" , [operands [0 ]])
256
199
return
257
200
258
201
# x ^ 0xFFFF..FF -> ~x
259
202
if inst .opcode == "xor" and lit_eq (operands [0 ], - 1 ):
260
- self .updater ._update (inst , "not" , [operands [1 ]])
203
+ self .updater .update (inst , "not" , [operands [1 ]])
261
204
return
262
205
263
206
return
264
207
265
208
# x & 0xFF..FF -> x
266
209
if inst .opcode == "and" and lit_eq (operands [0 ], - 1 ):
267
- self .updater ._store (inst , operands [1 ])
210
+ self .updater .store (inst , operands [1 ])
268
211
return
269
212
270
213
if inst .opcode in ("mul" , "and" , "div" , "sdiv" , "mod" , "smod" ):
271
214
# (x * 0) == (x & 0) == (x // 0) == (x % 0) -> 0
272
215
if any (lit_eq (op , 0 ) for op in operands ):
273
- self .updater ._store (inst , IRLiteral (0 ))
216
+ self .updater .store (inst , IRLiteral (0 ))
274
217
return
275
218
276
219
if inst .opcode in {"mul" , "div" , "sdiv" , "mod" , "smod" }:
277
220
if inst .opcode in ("mod" , "smod" ) and lit_eq (operands [0 ], 1 ):
278
221
# x % 1 -> 0
279
- self .updater ._store (inst , IRLiteral (0 ))
222
+ self .updater .store (inst , IRLiteral (0 ))
280
223
return
281
224
282
225
# (x * 1) == (1 * x) == (x // 1) -> x
283
226
if inst .opcode in ("mul" , "div" , "sdiv" ) and lit_eq (operands [0 ], 1 ):
284
- self .updater ._store (inst , operands [1 ])
227
+ self .updater .store (inst , operands [1 ])
285
228
return
286
229
287
230
if self ._is_lit (operands [0 ]) and is_power_of_two (operands [0 ].value ):
288
231
val = operands [0 ].value
289
232
# x % (2^n) -> x & (2^n - 1)
290
233
if inst .opcode == "mod" :
291
- self .updater ._update (inst , "and" , [IRLiteral (val - 1 ), operands [1 ]])
234
+ self .updater .update (inst , "and" , [IRLiteral (val - 1 ), operands [1 ]])
292
235
return
293
236
# x / (2^n) -> x >> n
294
237
if inst .opcode == "div" :
295
- self .updater ._update (inst , "shr" , [operands [1 ], IRLiteral (int_log2 (val ))])
238
+ self .updater .update (inst , "shr" , [operands [1 ], IRLiteral (int_log2 (val ))])
296
239
return
297
240
# x * (2^n) -> x << n
298
241
if inst .opcode == "mul" :
299
- self .updater ._update (inst , "shl" , [operands [1 ], IRLiteral (int_log2 (val ))])
242
+ self .updater .update (inst , "shl" , [operands [1 ], IRLiteral (int_log2 (val ))])
300
243
return
301
244
return
302
245
@@ -313,42 +256,42 @@ def _handle_inst_peephole(self, inst: IRInstruction):
313
256
if inst .opcode == "or" :
314
257
# x | 0xff..ff == 0xff..ff
315
258
if any (lit_eq (op , SizeLimits .MAX_UINT256 ) for op in operands ):
316
- self .updater ._store (inst , IRLiteral (SizeLimits .MAX_UINT256 ))
259
+ self .updater .store (inst , IRLiteral (SizeLimits .MAX_UINT256 ))
317
260
return
318
261
319
262
# x | n -> 1 in truthy positions (if n is non zero)
320
263
if is_truthy and self ._is_lit (operands [0 ]) and operands [0 ].value != 0 :
321
- self .updater ._store (inst , IRLiteral (1 ))
264
+ self .updater .store (inst , IRLiteral (1 ))
322
265
return
323
266
324
267
# x | 0 -> x
325
268
if lit_eq (operands [0 ], 0 ):
326
- self .updater ._store (inst , operands [1 ])
269
+ self .updater .store (inst , operands [1 ])
327
270
return
328
271
329
272
if inst .opcode == "eq" :
330
273
# x == x -> 1
331
274
if operands [0 ] == operands [1 ]:
332
- self .updater ._store (inst , IRLiteral (1 ))
275
+ self .updater .store (inst , IRLiteral (1 ))
333
276
return
334
277
335
278
# x == 0 -> iszero x
336
279
if lit_eq (operands [0 ], 0 ):
337
- self .updater ._update (inst , "iszero" , [operands [1 ]])
280
+ self .updater .update (inst , "iszero" , [operands [1 ]])
338
281
return
339
282
340
283
# eq x -1 -> iszero(~x)
341
284
# (saves codesize, not gas)
342
285
if lit_eq (operands [0 ], - 1 ):
343
- var = self .updater ._add_before (inst , "not" , [operands [1 ]])
344
- self .updater ._update (inst , "iszero" , [var ])
286
+ var = self .updater .add_before (inst , "not" , [operands [1 ]])
287
+ self .updater .update (inst , "iszero" , [var ])
345
288
return
346
289
347
290
if prefer_iszero :
348
291
# (eq x y) has the same truthyness as (iszero (xor x y))
349
- tmp = self .updater ._add_before (inst , "xor" , [operands [0 ], operands [1 ]])
292
+ tmp = self .updater .add_before (inst , "xor" , [operands [0 ], operands [1 ]])
350
293
351
- self .updater ._update (inst , "iszero" , [tmp ])
294
+ self .updater .update (inst , "iszero" , [tmp ])
352
295
return
353
296
354
297
if inst .opcode in COMPARATOR_INSTRUCTIONS :
@@ -361,7 +304,7 @@ def _optimize_comparator_instruction(self, inst, prefer_iszero):
361
304
362
305
# (x > x) == (x < x) -> 0
363
306
if operands [0 ] == operands [1 ]:
364
- self .updater ._store (inst , IRLiteral (0 ))
307
+ self .updater .store (inst , IRLiteral (0 ))
365
308
return
366
309
367
310
is_gt = "g" in opcode
@@ -388,31 +331,28 @@ def _optimize_comparator_instruction(self, inst, prefer_iszero):
388
331
almost_never = lo + 1
389
332
390
333
if lit_eq (operands [0 ], never ):
391
- self .updater ._store (inst , IRLiteral (0 ))
334
+ self .updater .store (inst , IRLiteral (0 ))
392
335
return
393
336
394
337
if lit_eq (operands [0 ], almost_never ):
395
338
# (lt x 1), (gt x (MAX_UINT256 - 1)), (slt x (MIN_INT256 + 1))
396
339
397
- # correct optimization:
398
- self .updater ._update (inst , "eq" , [operands [1 ], IRLiteral (never )])
399
- # canary:
400
- # self.updater._update(inst, "eq", [operands[1], IRLiteral(lo)])
340
+ self .updater .update (inst , "eq" , [operands [1 ], IRLiteral (never )])
401
341
return
402
342
403
343
# rewrites. in positions where iszero is preferred, (gt x 5) => (ge x 6)
404
344
if prefer_iszero and lit_eq (operands [0 ], almost_always ):
405
345
# e.g. gt x 0, slt x MAX_INT256
406
- tmp = self .updater ._add_before (inst , "eq" , operands )
407
- self .updater ._update (inst , "iszero" , [tmp ])
346
+ tmp = self .updater .add_before (inst , "eq" , operands )
347
+ self .updater .update (inst , "iszero" , [tmp ])
408
348
return
409
349
410
350
# since push0 was introduced in shanghai, it's potentially
411
351
# better to actually reverse this optimization -- i.e.
412
352
# replace iszero(iszero(x)) with (gt x 0)
413
353
if opcode == "gt" and lit_eq (operands [0 ], 0 ):
414
- tmp = self .updater ._add_before (inst , "iszero" , [operands [1 ]])
415
- self .updater ._update (inst , "iszero" , [tmp ])
354
+ tmp = self .updater .add_before (inst , "iszero" , [operands [1 ]])
355
+ self .updater .update (inst , "iszero" , [tmp ])
416
356
return
417
357
418
358
# rewrite comparisons by removing an `iszero`, e.g.
@@ -448,7 +388,7 @@ def _optimize_comparator_instruction(self, inst, prefer_iszero):
448
388
449
389
new_opcode = flip_comparison_opcode (opcode )
450
390
451
- self .updater ._update (inst , new_opcode , [IRLiteral (val ), operands [1 ]])
391
+ self .updater .update (inst , new_opcode , [IRLiteral (val ), operands [1 ]])
452
392
453
393
assert len (after .operands ) == 1
454
- self .updater ._update (after , "store" , after .operands )
394
+ self .updater .update (after , "store" , after .operands )
0 commit comments