Skip to content

Commit

Permalink
refactor transaction loader
Browse files Browse the repository at this point in the history
  • Loading branch information
gurukamath committed Mar 5, 2024
1 parent 79f7743 commit a53f57d
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 115 deletions.
118 changes: 8 additions & 110 deletions src/ethereum_spec_tools/evm_tools/loaders/fixture_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Any, Tuple

from ethereum import rlp
from ethereum.base_types import U64, U256, Bytes0
from ethereum.base_types import U256
from ethereum.crypto.hash import Hash32
from ethereum.utils.hexadecimal import (
hex_to_bytes,
Expand All @@ -21,15 +21,7 @@
)

from .fork_loader import ForkLoad


class UnsupportedTx(Exception):
"""Exception for unsupported transactions"""

def __init__(self, encoded_params: bytes, error_message: str) -> None:
super().__init__(error_message)
self.encoded_params = encoded_params
self.error_message = error_message
from .transaction_loader import TransactionLoad


class BaseLoad(ABC):
Expand All @@ -56,10 +48,12 @@ class Load(BaseLoad):

_network: str
_fork_module: str
fork: ForkLoad

def __init__(self, network: str, fork_name: str):
def __init__(self, network: str, fork_module: str):
self._network = network
self.fork = ForkLoad(fork_name)
self._fork_module = fork_module
self.fork = ForkLoad(fork_module)

def json_to_state(self, raw: Any) -> Any:
"""Converts json state data to a state object"""
Expand All @@ -84,103 +78,6 @@ def json_to_state(self, raw: Any) -> Any:
)
return state

def json_to_access_list(self, raw: Any) -> Any:
"""Converts json access list data to a list of access list entries"""
access_list = []
for sublist in raw:
access_list.append(
(
self.fork.hex_to_address(sublist.get("address")),
[
hex_to_bytes32(key)
for key in sublist.get("storageKeys")
],
)
)
return access_list

def json_to_tx(self, raw: Any) -> Any:
"""Converts json transaction data to a transaction object"""
parameters = [
hex_to_u256(raw.get("nonce")),
hex_to_u256(raw.get("gasLimit")),
Bytes0(b"")
if raw.get("to") == ""
else self.fork.hex_to_address(raw.get("to")),
hex_to_u256(raw.get("value")),
hex_to_bytes(raw.get("data")),
hex_to_u256(
raw.get("y_parity") if "y_parity" in raw else raw.get("v")
),
hex_to_u256(raw.get("r")),
hex_to_u256(raw.get("s")),
]

# Cancun and beyond
if "maxFeePerBlobGas" in raw:
parameters.insert(0, U64(1))
parameters.insert(2, hex_to_u256(raw.get("maxPriorityFeePerGas")))
parameters.insert(3, hex_to_u256(raw.get("maxFeePerGas")))
parameters.insert(
8, self.json_to_access_list(raw.get("accessList"))
)
parameters.insert(9, hex_to_u256(raw.get("maxFeePerBlobGas")))
parameters.insert(
10,
[
hex_to_hash(blob_hash)
for blob_hash in raw.get("blobVersionedHashes")
],
)

try:
return b"\x03" + rlp.encode(
self.fork.BlobTransaction(*parameters)
)
except AttributeError as e:
raise UnsupportedTx(
b"\x03" + rlp.encode(parameters), str(e)
) from e

# London and beyond
if "maxFeePerGas" in raw and "maxPriorityFeePerGas" in raw:
parameters.insert(0, U64(1))
parameters.insert(2, hex_to_u256(raw.get("maxPriorityFeePerGas")))
parameters.insert(3, hex_to_u256(raw.get("maxFeePerGas")))
parameters.insert(
8, self.json_to_access_list(raw.get("accessList"))
)
try:
return b"\x02" + rlp.encode(
self.fork.FeeMarketTransaction(*parameters)
)
except AttributeError as e:
raise UnsupportedTx(
b"\x02" + rlp.encode(parameters), str(e)
) from e

parameters.insert(1, hex_to_u256(raw.get("gasPrice")))
# Access List Transaction
if "accessList" in raw:
parameters.insert(0, U64(1))
parameters.insert(
7, self.json_to_access_list(raw.get("accessList"))
)
try:
return b"\x01" + rlp.encode(
self.fork.AccessListTransaction(*parameters)
)
except AttributeError as e:
raise UnsupportedTx(
b"\x01" + rlp.encode(parameters), str(e)
) from e

# Legacy Transaction
if hasattr(self.fork, "LegacyTransaction"):
return self.fork.LegacyTransaction(*parameters)
else:
return self.fork.Transaction(*parameters)

def json_to_withdrawals(self, raw: Any) -> Any:
"""Converts json withdrawal data to a withdrawal object"""
parameters = [
Expand All @@ -206,7 +103,8 @@ def json_to_block(

header = self.json_to_header(json_block["blockHeader"])
transactions = tuple(
self.json_to_tx(tx) for tx in json_block["transactions"]
TransactionLoad(tx, self.fork).read()
for tx in json_block["transactions"]
)
uncles = tuple(
self.json_to_header(uncle) for uncle in json_block["uncleHeaders"]
Expand Down
4 changes: 2 additions & 2 deletions src/ethereum_spec_tools/evm_tools/loaders/fork_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class ForkLoad:
_fork_module: str
_forks: Any

def __init__(self, fork_name: str):
self._fork_module = fork_name
def __init__(self, fork_module: str):
self._fork_module = fork_module
self._forks = Hardfork.discover()

@property
Expand Down
179 changes: 179 additions & 0 deletions src/ethereum_spec_tools/evm_tools/loaders/transaction_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""
Read transaction data from json file and return the
relevant transaction.
"""

from dataclasses import fields
from typing import Any, List

from ethereum import rlp
from ethereum.base_types import U64, U256, Bytes, Bytes0, Bytes32
from ethereum.utils.hexadecimal import (
hex_to_bytes,
hex_to_bytes32,
hex_to_hash,
hex_to_u256,
)


class UnsupportedTx(Exception):
"""Exception for unsupported transactions"""

def __init__(self, encoded_params: bytes, error_message: str) -> None:
super().__init__(error_message)
self.encoded_params = encoded_params
self.error_message = error_message


class TransactionLoad:
"""
Class for loading transaction data from json file
"""

def __init__(self, raw: Any, fork: Any) -> None:
self.raw = raw
self.fork = fork

def json_to_chain_id(self) -> U64:
"""Get chain ID for the transaction."""
return U64(1)

def json_to_nonce(self) -> U256:
"""Get the nonce for the transaction."""
return hex_to_u256(self.raw.get("nonce"))

def json_to_gas_price(self) -> U256:
"""Get the gas price for the transaction."""
return hex_to_u256(self.raw.get("gasPrice"))

def json_to_gas(self) -> U256:
"""Get the gas limit for the transaction."""
return hex_to_u256(self.raw.get("gasLimit"))

def json_to_to(self) -> Bytes:
"""Get to address for the transaction."""
return (
Bytes0(b"")
if self.raw.get("to") == ""
else self.fork.hex_to_address(self.raw.get("to"))
)

def json_to_value(self) -> U256:
"""Get the value of the transaction."""
return hex_to_u256(self.raw.get("value"))

def json_to_data(self) -> Bytes:
"""Get the data of the transaction."""
return hex_to_bytes(self.raw.get("data"))

def json_to_access_list(self) -> Any:
"""Get the access list of the transaction."""
access_list = []
for sublist in self.raw["accessList"]:
access_list.append(
(
self.fork.hex_to_address(sublist.get("address")),
[
hex_to_bytes32(key)
for key in sublist.get("storageKeys")
],
)
)
return access_list

def json_to_max_priority_fee_per_gas(self) -> U256:
"""Get the max priority fee per gas of the transaction."""
return hex_to_u256(self.raw.get("maxPriorityFeePerGas"))

def json_to_max_fee_per_gas(self) -> U256:
"""Get the max fee per gas of the transaction."""
return hex_to_u256(self.raw.get("maxFeePerGas"))

def json_to_max_fee_per_blob_gas(self) -> U256:
"""
Get the max priority fee per blobgas of the transaction.
"""
return hex_to_u256(self.raw.get("maxFeePerBlobGas"))

def json_to_blob_versioned_hashes(self) -> List[Bytes32]:
"""Get the blob versioned hashes of the transaction."""
return [
hex_to_hash(blob_hash)
for blob_hash in self.raw.get("blobVersionedHashes")
]

def json_to_v(self) -> U256:
"""Get the v value of the transaction."""
return hex_to_u256(
self.raw.get("y_parity")
if "y_parity" in self.raw
else self.raw.get("v")
)

def json_to_y_parity(self) -> U256:
"""Get the y parity of the transaction."""
return self.json_to_v()

def json_to_r(self) -> U256:
"""Get the r value of the transaction"""
return hex_to_u256(self.raw.get("r"))

def json_to_s(self) -> U256:
"""Get the s value of the transaction"""
return hex_to_u256(self.raw.get("s"))

def get_parameters(self, tx_cls: Any) -> List:
"""
Extract all the transaction parameters from the json file
"""
parameters = []
for field in fields(tx_cls):
parameters.append(getattr(self, f"json_to_{field.name}")())
return parameters

def get_legacy_transaction(self) -> Any:
"""Return the approprtiate class for legacy transactions."""
if hasattr(self.fork, "LegacyTransaction"):
return self.fork.LegacyTransaction
else:
return self.fork.Transaction

def read(self) -> Any:
"""Convert json transaction data to a transaction object"""
if "type" in self.raw:
tx_type = self.raw.get("type")
if tx_type == "0x3":
tx_cls = self.fork.BlobTransaction
tx_byte_prefix = b"\x03"
elif tx_type == "0x2":
tx_cls = self.fork.FeeMarketTransaction
tx_byte_prefix = b"\x02"
elif tx_type == "0x1":
tx_cls = self.fork.AccessListTransaction
tx_byte_prefix = b"\x01"
elif tx_type == "0x0":
tx_cls = self.get_legacy_transaction()
tx_byte_prefix = b""
else:
raise ValueError(f"Unknown transaction type: {tx_type}")
else:
if "maxFeePerBlobGas" in self.raw:
tx_cls = self.fork.BlobTransaction
tx_byte_prefix = b"\x03"
elif "maxFeePerGas" in self.raw:
tx_cls = self.fork.FeeMarketTransaction
tx_byte_prefix = b"\x02"
elif "accessList" in self.raw:
tx_cls = self.fork.AccessListTransaction
tx_byte_prefix = b"\x01"
else:
tx_cls = self.get_legacy_transaction()
tx_byte_prefix = b""

parameters = self.get_parameters(tx_cls)
try:
return tx_cls(*parameters)
except Exception as e:
raise UnsupportedTx(
tx_byte_prefix + rlp.encode(parameters), str(e)
) from e
6 changes: 3 additions & 3 deletions src/ethereum_spec_tools/evm_tools/t8n/t8n_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ethereum.crypto.hash import keccak256
from ethereum.utils.hexadecimal import hex_to_bytes, hex_to_u256, hex_to_uint

from ..loaders.fixture_loader import UnsupportedTx
from ..loaders.transaction_loader import TransactionLoad, UnsupportedTx
from ..utils import FatalException, secp256k1_sign


Expand Down Expand Up @@ -187,7 +187,7 @@ def parse_json_tx(self, raw_tx: Any) -> Any:
if "secretKey" in raw_tx and v == r == s == 0:
self.sign_transaction(raw_tx)

tx = t8n.json_to_tx(raw_tx)
tx = TransactionLoad(raw_tx, t8n.fork).read()
self.all_txs.append(tx)

if t8n.fork.is_after_fork("ethereum.berlin"):
Expand Down Expand Up @@ -239,7 +239,7 @@ def sign_transaction(self, json_tx: Any) -> None:
t8n = self.t8n
protected = json_tx.get("protected", True)

tx = t8n.json_to_tx(json_tx)
tx = TransactionLoad(json_tx, t8n.fork).read()

if isinstance(tx, bytes):
tx_decoded = t8n.fork.decode_transaction(tx)
Expand Down

0 comments on commit a53f57d

Please sign in to comment.