From 94fd4d0b3602cb79162ab268627e590e627defe2 Mon Sep 17 00:00:00 2001 From: Guruprasad Kamath Date: Wed, 26 Jun 2024 10:08:12 +0200 Subject: [PATCH] Implement EIP-4200 --- src/ethereum/prague/vm/eof.py | 106 +++++++++++++++-- src/ethereum/prague/vm/gas.py | 3 + .../prague/vm/instructions/__init__.py | 16 ++- .../prague/vm/instructions/control_flow.py | 110 +++++++++++++++++- whitelist.txt | 6 +- 5 files changed, 227 insertions(+), 14 deletions(-) diff --git a/src/ethereum/prague/vm/eof.py b/src/ethereum/prague/vm/eof.py index e48a8cb42c..79a3a3759f 100644 --- a/src/ethereum/prague/vm/eof.py +++ b/src/ethereum/prague/vm/eof.py @@ -13,7 +13,7 @@ """ from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Set from ethereum.base_types import Uint @@ -328,39 +328,123 @@ def validate_body(code: bytes, eof_header: EOFHeader) -> None: raise InvalidEOF("Stray bytes found after data section") -def validate_code_section(code: bytes) -> None: +def get_valid_jump_destinations(code: bytes) -> Set[int]: """ - Validate a code section of the EOF container. + Get the valid jump destinations for the code. The immediate bytes + of the PUSH, RJUMP, RJUMPI, RJUMPV opcodes are invalid as jump + destinations. Parameters ---------- code : bytes - The code section to validate. + The code section of the EOF container. - Raises - ------ - InvalidEOF - If the code section is invalid. + Returns + ------- + valid_jump_destinations : Set[int] + The valid jump destinations in the code. """ counter = 0 + valid_jump_destinations = set() while counter < len(code): try: opcode = get_opcode(code[counter], EOF.EOF1) except ValueError: raise InvalidEOF("Invalid opcode in code section") + valid_jump_destinations.add(counter) + + counter += 1 if ( opcode.value >= Ops.PUSH1.value and opcode.value <= Ops.PUSH32.value ): push_data_size = opcode.value - Ops.PUSH1.value + 1 - if len(code) < counter + push_data_size + 1: + if len(code) < counter + push_data_size: raise InvalidEOF("Push data missing") + counter += push_data_size + continue - counter += push_data_size + 1 + if opcode in (Ops.RJUMP, Ops.RJUMPI): + if len(code) < counter + 2: + raise InvalidEOF("Relative jump offset missing") + counter += 2 continue - counter += 1 + if opcode == Ops.RJUMPV: + if len(code) < counter + 1: + raise InvalidEOF("max_index missing for RJUMPV") + max_index = code[counter] + num_relative_indices = max_index + 1 + counter += 1 + + for _ in range(num_relative_indices): + if len(code) < counter + 2: + raise InvalidEOF("Relative jump indices missing") + counter += 2 + continue + + return valid_jump_destinations + + +def validate_code_section(code: bytes) -> None: + """ + Validate a code section of the EOF container. + + Parameters + ---------- + code : bytes + The code section to validate. + + Raises + ------ + InvalidEOF + If the code section is invalid. + """ + counter = 0 + valid_jump_destinations = get_valid_jump_destinations(code) + + for counter in valid_jump_destinations: + opcode = get_opcode(code[counter], EOF.EOF1) + + # Make sure the bytes encoding relative offset + # are available + if opcode in (Ops.RJUMP, Ops.RJUMPI): + relative_offset = int.from_bytes( + code[counter + 1 : counter + 3], "big", signed=True + ) + pc_post_instruction = counter + 3 + jump_destination = pc_post_instruction + relative_offset + if ( + jump_destination < 0 + or len(code) < jump_destination + 1 + or jump_destination not in valid_jump_destinations + ): + raise InvalidEOF("Invalid jump destination") + + elif opcode == Ops.RJUMPV: + num_relative_indices = code[counter + 1] + 1 + # pc_post_instruction will be + # counter + 1 <- for normal pc increment to next opcode + # + 1 <- for the 1 byte max_index + # + 2 * num_relative_indices <- for the 2 bytes of each offset + pc_post_instruction = counter + 2 + 2 * num_relative_indices + + index_position = counter + 2 + for _ in range(num_relative_indices): + relative_offset = int.from_bytes( + code[index_position : index_position + 2], + "big", + signed=True, + ) + index_position += 2 + jump_destination = pc_post_instruction + relative_offset + if ( + jump_destination < 0 + or len(code) < jump_destination + 1 + or jump_destination not in valid_jump_destinations + ): + raise InvalidEOF("Invalid jump destination") def validate_eof_code(code: bytes, eof_header: EOFHeader) -> None: diff --git a/src/ethereum/prague/vm/gas.py b/src/ethereum/prague/vm/gas.py index a47fc6e921..84dee991b1 100644 --- a/src/ethereum/prague/vm/gas.py +++ b/src/ethereum/prague/vm/gas.py @@ -66,6 +66,9 @@ GAS_INIT_CODE_WORD_COST = 2 GAS_BLOBHASH_OPCODE = Uint(3) GAS_POINT_EVALUATION = Uint(50000) +GAS_RJUMP = Uint(2) +GAS_RJUMPI = Uint(4) +GAS_RJUMPV = Uint(4) TARGET_BLOB_GAS_PER_BLOCK = U64(393216) GAS_PER_BLOB = Uint(2**17) diff --git a/src/ethereum/prague/vm/instructions/__init__.py b/src/ethereum/prague/vm/instructions/__init__.py index 1eb98b62d4..f1f48a5eb4 100644 --- a/src/ethereum/prague/vm/instructions/__init__.py +++ b/src/ethereum/prague/vm/instructions/__init__.py @@ -202,6 +202,11 @@ class Ops(enum.Enum): LOG3 = 0xA3 LOG4 = 0xA4 + # Static Relative Jumps + RJUMP = 0xE0 + RJUMPI = 0xE1 + RJUMPV = 0xE2 + # System Operations CREATE = 0xF0 CALL = 0xF1 @@ -355,6 +360,9 @@ class Ops(enum.Enum): Ops.LOG2: log_instructions.log2, Ops.LOG3: log_instructions.log3, Ops.LOG4: log_instructions.log4, + Ops.RJUMP: control_flow_instructions.rjump, + Ops.RJUMPI: control_flow_instructions.rjumpi, + Ops.RJUMPV: control_flow_instructions.rjumpv, Ops.CREATE: system_instructions.create, Ops.RETURN: system_instructions.return_, Ops.CALL: system_instructions.call, @@ -367,7 +375,13 @@ class Ops(enum.Enum): } -OPCODES_INVALID_IN_LEGACY = (Ops.INVALID,) +OPCODES_INVALID_IN_LEGACY = ( + Ops.INVALID, + # Relative Jump instructions + Ops.RJUMP, + Ops.RJUMPI, + Ops.RJUMPV, +) OPCODES_INVALID_IN_EOF1 = ( # Control Flow Ops diff --git a/src/ethereum/prague/vm/instructions/control_flow.py b/src/ethereum/prague/vm/instructions/control_flow.py index a967ade61d..57fe94b7a2 100644 --- a/src/ethereum/prague/vm/instructions/control_flow.py +++ b/src/ethereum/prague/vm/instructions/control_flow.py @@ -14,7 +14,16 @@ from ethereum.base_types import U256, Uint -from ...vm.gas import GAS_BASE, GAS_HIGH, GAS_JUMPDEST, GAS_MID, charge_gas +from ...vm.gas import ( + GAS_BASE, + GAS_HIGH, + GAS_JUMPDEST, + GAS_MID, + GAS_RJUMP, + GAS_RJUMPI, + GAS_RJUMPV, + charge_gas, +) from .. import Evm from ..exceptions import InvalidJumpDestError from ..stack import pop, push @@ -169,3 +178,102 @@ def jumpdest(evm: Evm) -> None: # PROGRAM COUNTER evm.pc += 1 + + +def rjump(evm: Evm) -> None: + """ + Jump to a relative offset. + + Parameters + ---------- + evm : + The current EVM frame. + + """ + # STACK + pass + + # GAS + charge_gas(evm, GAS_RJUMP) + + # OPERATION + pass + + # PROGRAM COUNTER + relative_offset = int.from_bytes( + evm.code[evm.pc + 1 : evm.pc + 3], "big", signed=True + ) + # pc + 1 + 2 bytes of relative offset + pc_post_instruction = int(evm.pc) + 3 + evm.pc = Uint(pc_post_instruction + relative_offset) + + +def rjumpi(evm: Evm) -> None: + """ + Jump to a relative offset given a condition. + + Parameters + ---------- + evm : + The current EVM frame. + + """ + # STACK + condition = pop(evm.stack) + + # GAS + charge_gas(evm, GAS_RJUMPI) + + # OPERATION + pass + + # PROGRAM COUNTER + relative_offset = int.from_bytes( + evm.code[evm.pc + 1 : evm.pc + 3], "big", signed=True + ) + # pc + 1 + 2 bytes of relative offset + pc_post_instruction = int(evm.pc) + 3 + if condition == 0: + evm.pc = Uint(pc_post_instruction) + else: + evm.pc = Uint(pc_post_instruction + relative_offset) + + +def rjumpv(evm: Evm) -> None: + """ + Jump to a relative offset via jump table. + + Parameters + ---------- + evm : + The current EVM frame. + + """ + # STACK + case = pop(evm.stack) + + # GAS + charge_gas(evm, GAS_RJUMPV) + + # OPERATION + pass + + # PROGRAM COUNTER + max_index = evm.code[evm.pc + 1] + num_relative_indices = max_index + 1 + # pc_post_instruction will be + # counter + 1 <- for normal pc increment to next opcode + # + 1 <- for the 1 byte max_index + # + 2 * num_relative_indices <- for the 2 bytes of each offset + pc_post_instruction = int(evm.pc) + 2 + 2 * num_relative_indices + + if case > max_index: + evm.pc = Uint(pc_post_instruction) + else: + relative_offset_position = evm.pc + 2 + 2 * case + relative_offset = int.from_bytes( + evm.code[relative_offset_position : relative_offset_position + 2], + "big", + signed=True, + ) + evm.pc = Uint(pc_post_instruction + relative_offset) diff --git a/whitelist.txt b/whitelist.txt index 174c96f182..026eff524e 100644 --- a/whitelist.txt +++ b/whitelist.txt @@ -438,4 +438,8 @@ req predeploy eof -eof1 \ No newline at end of file +eof1 + +RJUMP +RJUMPI +RJUMPV \ No newline at end of file