Skip to content

Commit 41108f9

Browse files
refactor[venom]: use InstUpdater in more passes (#4508)
move `InstUpdater` to its own file and use it in more passes - it's safer and faster for passes which use the `dfg` to modify it in place using `InstUpdater`, rather than modifying instructions (without updating the `dfg`), and then relying on a stale `dfg`.
1 parent ded6b2c commit 41108f9

9 files changed

+141
-138
lines changed

vyper/utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,11 @@ def addmany(self, iterable):
6767
def remove(self, item: _T) -> None:
6868
del self._data[item]
6969

70-
def drop(self, item: _T):
70+
def discard(self, item: _T):
7171
# friendly version of remove
7272
self._data.pop(item, None)
7373

74+
# consider renaming to "discardmany"
7475
def dropmany(self, iterable):
7576
for item in iterable:
7677
self._data.pop(item, None)

vyper/venom/passes/algebraic_optimization.py

+38-98
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
IRVariable,
1111
flip_comparison_opcode,
1212
)
13-
from vyper.venom.passes.base_pass import IRPass
13+
from vyper.venom.passes.base_pass import InstUpdater, IRPass
1414

1515
TRUTHY_INSTRUCTIONS = ("iszero", "jnz", "assert", "assert_unreachable")
1616

@@ -19,62 +19,6 @@ def lit_eq(op: IROperand, val: int) -> bool:
1919
return isinstance(op, IRLiteral) and wrap256(op.value) == wrap256(val)
2020

2121

22-
class InstructionUpdater:
23-
"""
24-
A helper class for updating instructions which also updates the
25-
basic block and dfg in place
26-
"""
27-
28-
def __init__(self, dfg: DFGAnalysis):
29-
self.dfg = dfg
30-
31-
def _update_operands(self, inst: IRInstruction, replace_dict: dict[IROperand, IROperand]):
32-
old_operands = inst.operands
33-
new_operands = [replace_dict[op] if op in replace_dict else op for op in old_operands]
34-
self._update(inst, inst.opcode, new_operands)
35-
36-
def _update(self, inst: IRInstruction, opcode: str, new_operands: list[IROperand]):
37-
assert opcode != "phi"
38-
# sanity
39-
assert all(isinstance(op, IROperand) for op in new_operands)
40-
41-
old_operands = inst.operands
42-
43-
for op in old_operands:
44-
if not isinstance(op, IRVariable):
45-
continue
46-
uses = self.dfg.get_uses(op)
47-
if inst in uses:
48-
uses.remove(inst)
49-
50-
for op in new_operands:
51-
if isinstance(op, IRVariable):
52-
self.dfg.add_use(op, inst)
53-
54-
inst.opcode = opcode
55-
inst.operands = new_operands
56-
57-
def _store(self, inst: IRInstruction, op: IROperand):
58-
self._update(inst, "store", [op])
59-
60-
def _add_before(self, inst: IRInstruction, opcode: str, args: list[IROperand]) -> IRVariable:
61-
"""
62-
Insert another instruction before the given instruction
63-
"""
64-
assert opcode != "phi"
65-
index = inst.parent.instructions.index(inst)
66-
var = inst.parent.parent.get_next_variable()
67-
operands = list(args)
68-
new_inst = IRInstruction(opcode, operands, output=var)
69-
inst.parent.insert_instruction(new_inst, index)
70-
for op in new_inst.operands:
71-
if isinstance(op, IRVariable):
72-
self.dfg.add_use(op, new_inst)
73-
self.dfg.add_use(var, inst)
74-
self.dfg.set_producing_instruction(var, new_inst)
75-
return var
76-
77-
7822
class AlgebraicOptimizationPass(IRPass):
7923
"""
8024
This pass reduces algebraic evaluatable expressions.
@@ -86,18 +30,17 @@ class AlgebraicOptimizationPass(IRPass):
8630
"""
8731

8832
dfg: DFGAnalysis
89-
updater: InstructionUpdater
33+
updater: InstUpdater
9034

9135
def run_pass(self):
9236
self.dfg = self.analyses_cache.request_analysis(DFGAnalysis) # type: ignore
93-
self.updater = InstructionUpdater(self.dfg)
37+
self.updater = InstUpdater(self.dfg)
9438
self._handle_offset()
9539

9640
self._algebraic_opt()
9741
self._optimize_iszero_chains()
9842
self._algebraic_opt()
9943

100-
self.analyses_cache.invalidate_analysis(DFGAnalysis)
10144
self.analyses_cache.invalidate_analysis(LivenessAnalysis)
10245

10346
def _optimize_iszero_chains(self) -> None:
@@ -132,7 +75,7 @@ def _optimize_iszero_chains(self) -> None:
13275
continue
13376

13477
out_var = iszero_chain[keep_count].operands[0]
135-
self.updater._update_operands(use_inst, {inst.output: out_var})
78+
self.updater.update_operands(use_inst, {inst.output: out_var})
13679

13780
def _get_iszero_chain(self, op: IROperand) -> list[IRInstruction]:
13881
chain: list[IRInstruction] = []
@@ -207,30 +150,30 @@ def _handle_inst_peephole(self, inst: IRInstruction):
207150
if inst.opcode in {"shl", "shr", "sar"}:
208151
# (x >> 0) == (x << 0) == x
209152
if lit_eq(operands[1], 0):
210-
self.updater._store(inst, operands[0])
153+
self.updater.store(inst, operands[0])
211154
return
212155
# no more cases for these instructions
213156
return
214157

215158
if inst.opcode == "exp":
216159
# x ** 0 -> 1
217160
if lit_eq(operands[0], 0):
218-
self.updater._store(inst, IRLiteral(1))
161+
self.updater.store(inst, IRLiteral(1))
219162
return
220163

221164
# 1 ** x -> 1
222165
if lit_eq(operands[1], 1):
223-
self.updater._store(inst, IRLiteral(1))
166+
self.updater.store(inst, IRLiteral(1))
224167
return
225168

226169
# 0 ** x -> iszero x
227170
if lit_eq(operands[1], 0):
228-
self.updater._update(inst, "iszero", [operands[0]])
171+
self.updater.update(inst, "iszero", [operands[0]])
229172
return
230173

231174
# x ** 1 -> x
232175
if lit_eq(operands[0], 1):
233-
self.updater._store(inst, operands[1])
176+
self.updater.store(inst, operands[1])
234177
return
235178

236179
# no more cases for this instruction
@@ -239,64 +182,64 @@ def _handle_inst_peephole(self, inst: IRInstruction):
239182
if inst.opcode in {"add", "sub", "xor"}:
240183
# (x - x) == (x ^ x) == 0
241184
if inst.opcode in ("xor", "sub") and operands[0] == operands[1]:
242-
self.updater._store(inst, IRLiteral(0))
185+
self.updater.store(inst, IRLiteral(0))
243186
return
244187

245188
# (x + 0) == (0 + x) -> x
246189
# x - 0 -> x
247190
# (x ^ 0) == (0 ^ x) -> x
248191
if lit_eq(operands[0], 0):
249-
self.updater._store(inst, operands[1])
192+
self.updater.store(inst, operands[1])
250193
return
251194

252195
# (-1) - x -> ~x
253196
# from two's complement
254197
if inst.opcode == "sub" and lit_eq(operands[1], -1):
255-
self.updater._update(inst, "not", [operands[0]])
198+
self.updater.update(inst, "not", [operands[0]])
256199
return
257200

258201
# x ^ 0xFFFF..FF -> ~x
259202
if inst.opcode == "xor" and lit_eq(operands[0], -1):
260-
self.updater._update(inst, "not", [operands[1]])
203+
self.updater.update(inst, "not", [operands[1]])
261204
return
262205

263206
return
264207

265208
# x & 0xFF..FF -> x
266209
if inst.opcode == "and" and lit_eq(operands[0], -1):
267-
self.updater._store(inst, operands[1])
210+
self.updater.store(inst, operands[1])
268211
return
269212

270213
if inst.opcode in ("mul", "and", "div", "sdiv", "mod", "smod"):
271214
# (x * 0) == (x & 0) == (x // 0) == (x % 0) -> 0
272215
if any(lit_eq(op, 0) for op in operands):
273-
self.updater._store(inst, IRLiteral(0))
216+
self.updater.store(inst, IRLiteral(0))
274217
return
275218

276219
if inst.opcode in {"mul", "div", "sdiv", "mod", "smod"}:
277220
if inst.opcode in ("mod", "smod") and lit_eq(operands[0], 1):
278221
# x % 1 -> 0
279-
self.updater._store(inst, IRLiteral(0))
222+
self.updater.store(inst, IRLiteral(0))
280223
return
281224

282225
# (x * 1) == (1 * x) == (x // 1) -> x
283226
if inst.opcode in ("mul", "div", "sdiv") and lit_eq(operands[0], 1):
284-
self.updater._store(inst, operands[1])
227+
self.updater.store(inst, operands[1])
285228
return
286229

287230
if self._is_lit(operands[0]) and is_power_of_two(operands[0].value):
288231
val = operands[0].value
289232
# x % (2^n) -> x & (2^n - 1)
290233
if inst.opcode == "mod":
291-
self.updater._update(inst, "and", [IRLiteral(val - 1), operands[1]])
234+
self.updater.update(inst, "and", [IRLiteral(val - 1), operands[1]])
292235
return
293236
# x / (2^n) -> x >> n
294237
if inst.opcode == "div":
295-
self.updater._update(inst, "shr", [operands[1], IRLiteral(int_log2(val))])
238+
self.updater.update(inst, "shr", [operands[1], IRLiteral(int_log2(val))])
296239
return
297240
# x * (2^n) -> x << n
298241
if inst.opcode == "mul":
299-
self.updater._update(inst, "shl", [operands[1], IRLiteral(int_log2(val))])
242+
self.updater.update(inst, "shl", [operands[1], IRLiteral(int_log2(val))])
300243
return
301244
return
302245

@@ -313,42 +256,42 @@ def _handle_inst_peephole(self, inst: IRInstruction):
313256
if inst.opcode == "or":
314257
# x | 0xff..ff == 0xff..ff
315258
if any(lit_eq(op, SizeLimits.MAX_UINT256) for op in operands):
316-
self.updater._store(inst, IRLiteral(SizeLimits.MAX_UINT256))
259+
self.updater.store(inst, IRLiteral(SizeLimits.MAX_UINT256))
317260
return
318261

319262
# x | n -> 1 in truthy positions (if n is non zero)
320263
if is_truthy and self._is_lit(operands[0]) and operands[0].value != 0:
321-
self.updater._store(inst, IRLiteral(1))
264+
self.updater.store(inst, IRLiteral(1))
322265
return
323266

324267
# x | 0 -> x
325268
if lit_eq(operands[0], 0):
326-
self.updater._store(inst, operands[1])
269+
self.updater.store(inst, operands[1])
327270
return
328271

329272
if inst.opcode == "eq":
330273
# x == x -> 1
331274
if operands[0] == operands[1]:
332-
self.updater._store(inst, IRLiteral(1))
275+
self.updater.store(inst, IRLiteral(1))
333276
return
334277

335278
# x == 0 -> iszero x
336279
if lit_eq(operands[0], 0):
337-
self.updater._update(inst, "iszero", [operands[1]])
280+
self.updater.update(inst, "iszero", [operands[1]])
338281
return
339282

340283
# eq x -1 -> iszero(~x)
341284
# (saves codesize, not gas)
342285
if lit_eq(operands[0], -1):
343-
var = self.updater._add_before(inst, "not", [operands[1]])
344-
self.updater._update(inst, "iszero", [var])
286+
var = self.updater.add_before(inst, "not", [operands[1]])
287+
self.updater.update(inst, "iszero", [var])
345288
return
346289

347290
if prefer_iszero:
348291
# (eq x y) has the same truthyness as (iszero (xor x y))
349-
tmp = self.updater._add_before(inst, "xor", [operands[0], operands[1]])
292+
tmp = self.updater.add_before(inst, "xor", [operands[0], operands[1]])
350293

351-
self.updater._update(inst, "iszero", [tmp])
294+
self.updater.update(inst, "iszero", [tmp])
352295
return
353296

354297
if inst.opcode in COMPARATOR_INSTRUCTIONS:
@@ -361,7 +304,7 @@ def _optimize_comparator_instruction(self, inst, prefer_iszero):
361304

362305
# (x > x) == (x < x) -> 0
363306
if operands[0] == operands[1]:
364-
self.updater._store(inst, IRLiteral(0))
307+
self.updater.store(inst, IRLiteral(0))
365308
return
366309

367310
is_gt = "g" in opcode
@@ -388,31 +331,28 @@ def _optimize_comparator_instruction(self, inst, prefer_iszero):
388331
almost_never = lo + 1
389332

390333
if lit_eq(operands[0], never):
391-
self.updater._store(inst, IRLiteral(0))
334+
self.updater.store(inst, IRLiteral(0))
392335
return
393336

394337
if lit_eq(operands[0], almost_never):
395338
# (lt x 1), (gt x (MAX_UINT256 - 1)), (slt x (MIN_INT256 + 1))
396339

397-
# correct optimization:
398-
self.updater._update(inst, "eq", [operands[1], IRLiteral(never)])
399-
# canary:
400-
# self.updater._update(inst, "eq", [operands[1], IRLiteral(lo)])
340+
self.updater.update(inst, "eq", [operands[1], IRLiteral(never)])
401341
return
402342

403343
# rewrites. in positions where iszero is preferred, (gt x 5) => (ge x 6)
404344
if prefer_iszero and lit_eq(operands[0], almost_always):
405345
# e.g. gt x 0, slt x MAX_INT256
406-
tmp = self.updater._add_before(inst, "eq", operands)
407-
self.updater._update(inst, "iszero", [tmp])
346+
tmp = self.updater.add_before(inst, "eq", operands)
347+
self.updater.update(inst, "iszero", [tmp])
408348
return
409349

410350
# since push0 was introduced in shanghai, it's potentially
411351
# better to actually reverse this optimization -- i.e.
412352
# replace iszero(iszero(x)) with (gt x 0)
413353
if opcode == "gt" and lit_eq(operands[0], 0):
414-
tmp = self.updater._add_before(inst, "iszero", [operands[1]])
415-
self.updater._update(inst, "iszero", [tmp])
354+
tmp = self.updater.add_before(inst, "iszero", [operands[1]])
355+
self.updater.update(inst, "iszero", [tmp])
416356
return
417357

418358
# rewrite comparisons by removing an `iszero`, e.g.
@@ -448,7 +388,7 @@ def _optimize_comparator_instruction(self, inst, prefer_iszero):
448388

449389
new_opcode = flip_comparison_opcode(opcode)
450390

451-
self.updater._update(inst, new_opcode, [IRLiteral(val), operands[1]])
391+
self.updater.update(inst, new_opcode, [IRLiteral(val), operands[1]])
452392

453393
assert len(after.operands) == 1
454-
self.updater._update(after, "store", after.operands)
394+
self.updater.update(after, "store", after.operands)

vyper/venom/passes/base_pass.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from vyper.venom.analysis import IRAnalysesCache
22
from vyper.venom.context import IRContext
33
from vyper.venom.function import IRFunction
4+
from vyper.venom.passes.machinery.inst_updater import InstUpdater
45

56

67
class IRPass:
@@ -10,6 +11,7 @@ class IRPass:
1011

1112
function: IRFunction
1213
analyses_cache: IRAnalysesCache
14+
updater: InstUpdater # optional, does not need to be instantiated
1315

1416
def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction):
1517
self.function = function

vyper/venom/passes/dft.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def run_pass(self) -> None:
2121
self.data_offspring = {}
2222
self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet()
2323

24-
self.dfg = self.analyses_cache.request_analysis(DFGAnalysis)
24+
self.dfg = self.analyses_cache.force_analysis(DFGAnalysis)
2525

2626
for bb in self.function.get_basic_blocks():
2727
self._process_basic_block(bb)

vyper/venom/passes/load_elimination.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis
44
from vyper.venom.basicblock import IRLiteral
55
from vyper.venom.effects import Effects
6-
from vyper.venom.passes.base_pass import IRPass
6+
from vyper.venom.passes.base_pass import InstUpdater, IRPass
77

88

99
def _conflict(store_opcode: str, k1: IRLiteral, k2: IRLiteral):
@@ -23,7 +23,11 @@ class LoadElimination(IRPass):
2323

2424
# should this be renamed to EffectsElimination?
2525

26+
updater: InstUpdater
27+
2628
def run_pass(self):
29+
self.updater = InstUpdater(self.analyses_cache.request_analysis(DFGAnalysis))
30+
2731
for bb in self.function.get_basic_blocks():
2832
self._process_bb(bb, Effects.MEMORY, "mload", "mstore")
2933
self._process_bb(bb, Effects.TRANSIENT, "tload", "tstore")
@@ -85,7 +89,7 @@ def _handle_store(self, inst, store_opcode):
8589
# we found a redundant store, eliminate it
8690
existing_val = self._lattice.get(known_ptr)
8791
if self.equivalent(val, existing_val):
88-
inst.make_nop()
92+
self.updater.nop(inst)
8993
return
9094

9195
self._lattice[known_ptr] = val

0 commit comments

Comments
 (0)