diff --git a/EIPS/eip-7495.md b/EIPS/eip-7495.md index d6f0455fa3e6d..9dbf2fbc58fb8 100644 --- a/EIPS/eip-7495.md +++ b/EIPS/eip-7495.md @@ -124,16 +124,21 @@ Merkleization `hash_tree_root(value)` of an object `value` is extended with: ### `Profile[B]` -`Profile[B]` also defines an ordered heterogeneous collection of fields, a subset of fields of a base `StableContainer[N]` `B`. - -The rules for construction are: +`Profile[B]` also defines an ordered heterogeneous collection of fields, a subset of fields of a base `StableContainer` type `B` with the following constraints: - Fields in `Profile[B]` correspond to fields with the same field name in `B`. -- Fields in `Profile[B]` must be in the same order as in `B` -- Fields in the base type `B` MAY be kept `Optional` in `Profile[B]` -- Fields in the base type `B` MAY be required in `Profile[B]` by unwrapping them from `Optional` -- Fields in the base type `B` MAY be omitted in `Profile[B]`, disallowing their presence in the sub-type -- Fields in the base type `B` of type `Optional[T]` with `T` being a nested `StableContainer` MAY have types `Optional[Profile[T]]` (if kept optional), or `Profile[T]` (if it is required) +- Fields in `Profile[B]` follow the same order as in `B`. +- Fields in the base `StableContainer` type `B` are all `Optional`. + - Fields MAY be disallowed in `Profile[B]` by omitting them. + - Fields MAY be kept optional in `Profile[B]` by retaining them as `Optional`. + - Fields MAY be required in `Profile[B]` by unwrapping them from `Optional`. +- All field types in `Profile[B]` MUST be compatible with the corresponding field types in `B`. + - Field types are compatible with themselves. + - `byte` is compatible with `uint8` and vice versa. + - `Bitlist[N]` / `Bitvector[N]` field types are compatible if they share the same capacity `N`. + - `List[T, N]` / `Vector[T, N]` field types are compatible if `T` is compatible and if they also share the same capacity `N`. 
+ - `Container` / `StableContainer[N]` field types are compatible if all inner field types are compatible, if they share the same field names in the same order, and, for `StableContainer[N]`, if they also share the same capacity `N`. + - `Profile[X]` field types are compatible with `StableContainer` types that are compatible with `X`. They are also compatible with `Profile[Y]` where `Y` is compatible with `X` and all inner field types of the two profiles are compatible. Differences solely in optionality do not affect merkleization compatibility. #### Serialization @@ -205,7 +210,7 @@ Typically, the individual `Union` cases share some form of thematic overlap, sha Furthermore, SSZ Union types are currently not used in any final Ethereum specification and do not have a finalized design themselves. The `StableContainer[N]` serializes very similar to current `Union[T, U, V]` proposals, with the difference being a `Bitvector[N]` as a prefix instead of a selector byte. This means that the serialized byte lengths are comparable. -### Why not a `Container` full of `Optional[T]`? +### Why not model `Optional[T]` as an SSZ type? If `Optional[T]` is modeled as an SSZ type, each individual field introduces serialization and merkleization overhead. As an `Optional[T]` would be required to be ["variable-size"](https://github.com/ethereum/consensus-specs/blob/4afe39822c9ad9747e0f5635cca117c18441ec1b/ssz/simple-serialize.md#variable-size-and-fixed-size), lots of additional offset bytes would have to be used in the serialization. For merkleization, each individual `Optional[T]` would require mixing in a bit to indicate presence or absence of the value. 
diff --git a/assets/eip-7495/stable_container.py b/assets/eip-7495/stable_container.py index c265406b64ee3..32392a814882e 100644 --- a/assets/eip-7495/stable_container.py +++ b/assets/eip-7495/stable_container.py @@ -1,27 +1,50 @@ import io -from typing import Any, BinaryIO, Dict, List as PyList, Optional, Tuple, TypeVar, Type, Union as PyUnion, \ +from typing import Any, BinaryIO, Dict, List as PyList, Optional, Tuple, \ + TypeVar, Type, Union as PyUnion, \ get_args, get_origin from textwrap import indent -from remerkleable.bitfields import Bitvector -from remerkleable.complex import ComplexView, Container, FieldOffset, \ +from remerkleable.basic import boolean, uint8, uint16, uint32, uint64, uint128, uint256 +from remerkleable.bitfields import Bitlist, Bitvector +from remerkleable.byte_arrays import ByteList, ByteVector +from remerkleable.complex import ComplexView, Container, FieldOffset, List, Vector, \ decode_offset, encode_offset -from remerkleable.core import View, ViewHook, OFFSET_BYTE_LENGTH +from remerkleable.core import View, ViewHook, ViewMeta, OFFSET_BYTE_LENGTH from remerkleable.tree import Gindex, NavigationError, Node, PairNode, \ get_depth, subtree_fill_to_contents, zero_node, \ RIGHT_GINDEX -N = TypeVar('N') -B = TypeVar('B', bound="ComplexView") -S = TypeVar('S', bound="ComplexView") +N = TypeVar('N', bound=int) +SV = TypeVar('SV', bound='ComplexView') +BV = TypeVar('BV', bound='ComplexView') -def all_fields(cls) -> Dict[str, Tuple[Type[View], bool]]: - fields = {} - for k, v in cls.__annotations__.items(): - fopt = get_origin(v) == PyUnion and type(None) in get_args(v) - ftyp = get_args(v)[0] if fopt else v - fields[k] = (ftyp, fopt) - return fields +def stable_get(self, findex, ftyp, n): + if not self.active_fields().get(findex): + return None + data = self.get_backing().get_left() + fnode = data.getter(2**get_depth(n) + findex) + return ftyp.view_from_backing(fnode) + + +def stable_set(self, findex, ftyp, n, value): + next_backing = 
self.get_backing() + + active_fields = self.active_fields() + active_fields.set(findex, value is not None) + next_backing = next_backing.rebind_right(active_fields.get_backing()) + + if value is not None: + if isinstance(value, ftyp): + fnode = value.get_backing() + else: + fnode = ftyp.coerce_view(value).get_backing() + else: + fnode = zero_node(0) + data = next_backing.get_left() + next_data = data.setter(2**get_depth(n) + findex)(fnode) + next_backing = next_backing.rebind_left(next_data) + + self.set_backing(next_backing) def field_val_repr(self, fkey: str, ftyp: Type[View], fopt: bool) -> str: @@ -35,37 +58,25 @@ def field_val_repr(self, fkey: str, ftyp: Type[View], fopt: bool) -> str: field_repr = field_repr[:i+1] + indent(field_repr[i+1:], ' ' * len(field_start)) return field_start + field_repr except NavigationError: - return f"{field_start} *omitted*" - - -def repr(self) -> str: - return f"{self.__class__.type_repr()}:\n" + '\n'.join( - indent(field_val_repr(self, fkey, ftyp, fopt), ' ') - for fkey, (ftyp, fopt) in self.__class__.fields().items()) + return f'{field_start} *omitted*' class StableContainer(ComplexView): - _field_indices: Dict[str, Tuple[int, Type[View], bool]] - __slots__ = '_field_indices' + __slots__ = '_field_indices', 'N' + _field_indices: Dict[str, Tuple[int, Type[View]]] + N: int def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None, **kwargs): if backing is not None: if len(kwargs) != 0: - raise Exception("cannot have both a backing and elements to init fields") + raise Exception('Cannot have both a backing and elements to init fields') return super().__new__(cls, backing=backing, hook=hook, **kwargs) - for fkey, (ftyp, fopt) in cls.fields().items(): - if fkey not in kwargs: - if not fopt: - raise AttributeError(f"Field '{fkey}' is required in {cls}") - kwargs[fkey] = None - input_nodes = [] active_fields = Bitvector[cls.N]() - for findex, (fkey, (ftyp, fopt)) in enumerate(cls.fields().items()): + for fkey, 
(findex, ftyp) in cls._field_indices.items(): fnode: Node - assert fkey in kwargs - finput = kwargs.pop(fkey) + finput = kwargs.pop(fkey) if fkey in kwargs else None if finput is None: fnode = zero_node(0) active_fields.set(findex, False) @@ -76,35 +87,59 @@ def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None fnode = ftyp.coerce_view(finput).get_backing() active_fields.set(findex, True) input_nodes.append(fnode) - if len(kwargs) > 0: - raise AttributeError(f'The field names [{"".join(kwargs.keys())}] are not defined in {cls}') + raise AttributeError(f'Fields [{"".join(kwargs.keys())}] unknown in `{cls.__name__}`') backing = PairNode( left=subtree_fill_to_contents(input_nodes, get_depth(cls.N)), - right=active_fields.get_backing()) + right=active_fields.get_backing(), + ) return super().__new__(cls, backing=backing, hook=hook, **kwargs) - def __init_subclass__(cls, *args, **kwargs): - super().__init_subclass__(*args, **kwargs) - cls._field_indices = { - fkey: (i, ftyp, fopt) - for i, (fkey, (ftyp, fopt)) in enumerate(cls.fields().items()) - } - - def __class_getitem__(cls, n) -> Type["StableContainer"]: + def __init_subclass__(cls, **kwargs): + if 'n' not in kwargs: + raise TypeError(f'Missing capacity: `{cls.__name__}(StableContainer)`') + n = kwargs.pop('n') + if not isinstance(n, int): + raise TypeError(f'Invalid capacity: `StableContainer[{n}]`') if n <= 0: - raise Exception(f"invalid stablecontainer capacity: {n}") - - class StableContainerView(StableContainer): - N = n + raise TypeError(f'Unsupported capacity: `StableContainer[{n}]`') + cls.N = n + + def __class_getitem__(cls, n: int) -> Type['StableContainer']: + class StableContainerMeta(ViewMeta): + def __new__(cls, name, bases, dct): + return super().__new__(cls, name, bases, dct, n=n) + + class StableContainerView(StableContainer, metaclass=StableContainerMeta): + def __init_subclass__(cls, **kwargs): + if 'N' in cls.__dict__: + raise TypeError(f'Cannot override `N` inside 
`{cls.__name__}`') + cls._field_indices = {} + for findex, (fkey, t) in enumerate(cls.__annotations__.items()): + if ( + get_origin(t) != PyUnion + or len(get_args(t)) != 2 + or type(None) not in get_args(t) + ): + raise TypeError( + f'`StableContainer` fields must be `Optional[T]` ' + f'but `{cls.__name__}.{fkey}` has type `{t.__name__}`' + ) + ftyp = get_args(t)[0] if get_args(t)[0] is not type(None) else get_args(t)[1] + cls._field_indices[fkey] = (findex, ftyp) + if len(cls._field_indices) > cls.N: + raise TypeError( + f'`{cls.__name__}` is `StableContainer[{cls.N}]` ' + f'but contains {len(cls._field_indices)} fields' + ) StableContainerView.__name__ = StableContainerView.type_repr() return StableContainerView @classmethod - def fields(cls) -> Dict[str, Tuple[Type[View], bool]]: - return all_fields(cls) + def fields(cls) -> Dict[str, Type[View]]: + return { fkey: ftyp for fkey, (_, ftyp) in cls._field_indices.items() } @classmethod def is_fixed_byte_length(cls) -> bool: @@ -112,31 +147,36 @@ def is_fixed_byte_length(cls) -> bool: @classmethod def min_byte_length(cls) -> int: - total = Bitvector[cls.N].type_byte_length() - for _, (ftyp, fopt) in cls.fields().items(): - if fopt: - continue - if not ftyp.is_fixed_byte_length(): - total += OFFSET_BYTE_LENGTH - total += ftyp.min_byte_length() - return total + return Bitvector[cls.N].type_byte_length() @classmethod def max_byte_length(cls) -> int: total = Bitvector[cls.N].type_byte_length() - for _, (ftyp, _) in cls.fields().items(): + for (_, ftyp) in cls._field_indices.values(): if not ftyp.is_fixed_byte_length(): total += OFFSET_BYTE_LENGTH total += ftyp.max_byte_length() return total + @classmethod + def is_packed(cls) -> bool: + return False + + @classmethod + def tree_depth(cls) -> int: + return get_depth(cls.N) + + @classmethod + def item_elem_cls(cls, i: int) -> Type[View]: + return list(cls._field_indices.values())[i] + def active_fields(self) -> Bitvector: active_fields_node = 
super().get_backing().get_right() return Bitvector[self.__class__.N].view_from_backing(active_fields_node) def __getattribute__(self, item): if item == 'N': - raise AttributeError(f"use .__class__.{item} to access {item}") + raise AttributeError(f'Use `.__class__.{item}` to access `{item}`') return object.__getattribute__(self, item) def __getattr__(self, item): @@ -144,70 +184,49 @@ def __getattr__(self, item): return super().__getattribute__(item) else: try: - (findex, ftyp, fopt) = self.__class__._field_indices[item] + (findex, ftyp) = self.__class__._field_indices[item] except KeyError: - raise AttributeError(f"unknown attribute {item}") + raise AttributeError(f'Unknown field `{item}`') - if not self.active_fields().get(findex): - assert fopt - return None - - data = super().get_backing().get_left() - fnode = data.getter(2**get_depth(self.__class__.N) + findex) - return ftyp.view_from_backing(fnode) + return stable_get(self, findex, ftyp, self.__class__.N) def __setattr__(self, key, value): if key[0] == '_': super().__setattr__(key, value) else: try: - (findex, ftyp, fopt) = self.__class__._field_indices[key] + (findex, ftyp) = self.__class__._field_indices[key] except KeyError: - raise AttributeError(f"unknown attribute {key}") - - next_backing = self.get_backing() + raise AttributeError(f'Unknown field `{key}`') - assert value is not None or fopt - active_fields = self.active_fields() - active_fields.set(findex, value is not None) - next_backing = next_backing.rebind_right(active_fields.get_backing()) - - if value is not None: - if isinstance(value, ftyp): - fnode = value.get_backing() - else: - fnode = ftyp.coerce_view(value).get_backing() - else: - fnode = zero_node(0) - data = next_backing.get_left() - next_data = data.setter(2**get_depth(self.__class__.N) + findex)(fnode) - next_backing = next_backing.rebind_left(next_data) - - self.set_backing(next_backing) + stable_set(self, findex, ftyp, self.__class__.N, value) def __repr__(self): - return repr(self) 
+ return f'{self.__class__.type_repr()}:\n' + '\n'.join( + indent(field_val_repr(self, fkey, ftyp, fopt=True), ' ') + for fkey, (_, ftyp) in self.__class__._field_indices.items()) @classmethod def type_repr(cls) -> str: - return f"StableContainer[{cls.N}]" + return f'StableContainer[{cls.N}]' @classmethod - def deserialize(cls: Type[S], stream: BinaryIO, scope: int) -> S: + def deserialize(cls: Type[SV], stream: BinaryIO, scope: int) -> SV: num_prefix_bytes = Bitvector[cls.N].type_byte_length() if scope < num_prefix_bytes: - raise ValueError("scope too small, cannot read StableContainer active fields") + raise ValueError(f'Scope too small for `StableContainer[{cls.N}]` active fields') active_fields = Bitvector[cls.N].deserialize(stream, num_prefix_bytes) scope = scope - num_prefix_bytes - max_findex = 0 - field_values: Dict[str, Optional[View]] = {} + for findex in range(len(cls._field_indices), cls.N): + if active_fields.get(findex): + raise Exception(f'Unknown field index {findex}') + + field_values: Dict[str, View] = {} dyn_fields: PyList[FieldOffset] = [] fixed_size = 0 - for findex, (fkey, (ftyp, _)) in enumerate(cls.fields().items()): - max_findex = findex + for fkey, (findex, ftyp) in cls._field_indices.items(): if not active_fields.get(findex): - field_values[fkey] = None continue if ftyp.is_fixed_byte_length(): fsize = ftyp.type_byte_length() @@ -219,37 +238,41 @@ def deserialize(cls: Type[S], stream: BinaryIO, scope: int) -> S: fixed_size += OFFSET_BYTE_LENGTH if len(dyn_fields) > 0: if dyn_fields[0].offset < fixed_size: - raise Exception(f"first offset {dyn_fields[0].offset} is " - f"smaller than expected fixed size {fixed_size}") + raise Exception(f'First offset {dyn_fields[0].offset} is ' + f'smaller than expected fixed size {fixed_size}') for i, (fkey, ftyp, foffset) in enumerate(dyn_fields): next_offset = dyn_fields[i + 1].offset if i + 1 < len(dyn_fields) else scope if foffset > next_offset: - raise Exception(f"offset {i} is invalid: {foffset} " - 
f"larger than next offset {next_offset}") + raise Exception(f'Offset {i} is invalid: {foffset} ' + f'larger than next offset {next_offset}') fsize = next_offset - foffset f_min_size, f_max_size = ftyp.min_byte_length(), ftyp.max_byte_length() if not (f_min_size <= fsize <= f_max_size): - raise Exception(f"offset {i} is invalid, size out of bounds: " - f"{foffset}, next {next_offset}, implied size: {fsize}, " - f"size bounds: [{f_min_size}, {f_max_size}]") + raise Exception(f'Offset {i} is invalid, size out of bounds: ' + f'{foffset}, next {next_offset}, implied size: {fsize}, ' + f'size bounds: [{f_min_size}, {f_max_size}]') field_values[fkey] = ftyp.deserialize(stream, fsize) - for findex in range(max_findex + 1, cls.N): - if active_fields.get(findex): - raise Exception(f"unknown field index {findex}") return cls(**field_values) # type: ignore def serialize(self, stream: BinaryIO) -> int: active_fields = self.active_fields() num_prefix_bytes = active_fields.serialize(stream) - num_data_bytes = sum( - ftyp.type_byte_length() if ftyp.is_fixed_byte_length() else OFFSET_BYTE_LENGTH - for findex, (_, (ftyp, _)) in enumerate(self.__class__.fields().items()) - if active_fields.get(findex)) + num_data_bytes = 0 + has_dyn_fields = False + for (findex, ftyp) in self.__class__._field_indices.values(): + if not active_fields.get(findex): + continue + if ftyp.is_fixed_byte_length(): + num_data_bytes += ftyp.type_byte_length() + else: + num_data_bytes += OFFSET_BYTE_LENGTH + has_dyn_fields = True - temp_dyn_stream = io.BytesIO() + if has_dyn_fields: + temp_dyn_stream = io.BytesIO() data = super().get_backing().get_left() - for findex, (_, (ftyp, _)) in enumerate(self.__class__.fields().items()): + for (findex, ftyp) in self.__class__._field_indices.values(): if not active_fields.get(findex): continue fnode = data.getter(2**get_depth(self.__class__.N) + findex) @@ -259,8 +282,9 @@ def serialize(self, stream: BinaryIO) -> int: else: encode_offset(stream, num_data_bytes) 
num_data_bytes += v.serialize(temp_dyn_stream) # type: ignore - temp_dyn_stream.seek(0) - stream.write(temp_dyn_stream.read(num_data_bytes)) + if has_dyn_fields: + temp_dyn_stream.seek(0) + stream.write(temp_dyn_stream.read()) return num_prefix_bytes + num_data_bytes @@ -268,69 +292,216 @@ def serialize(self, stream: BinaryIO) -> int: def navigate_type(cls, key: Any) -> Type[View]: if key == '__active_fields__': return Bitvector[cls.N] - (_, ftyp, fopt) = cls._field_indices[key] - if fopt: - return Optional[ftyp] - return ftyp + (_, ftyp) = cls._field_indices[key] + return Optional[ftyp] @classmethod def key_to_static_gindex(cls, key: Any) -> Gindex: if key == '__active_fields__': return RIGHT_GINDEX - (findex, _, _) = cls._field_indices[key] + (findex, _) = cls._field_indices[key] return 2**get_depth(cls.N) * 2 + findex class Profile(ComplexView): + __slots__ = '_field_indices', '_o', 'B' + _field_indices: Dict[str, Tuple[int, Type[View], bool]] _o: int + B: Type[StableContainer] def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None, **kwargs): if backing is not None: if len(kwargs) != 0: - raise Exception("cannot have both a backing and elements to init fields") + raise Exception('Cannot have both a backing and elements to init fields') return super().__new__(cls, backing=backing, hook=hook, **kwargs) - extra_kwargs = kwargs.copy() - for fkey, (ftyp, fopt) in cls.fields().items(): - if fkey in extra_kwargs: - extra_kwargs.pop(fkey) + extra_kw = kwargs.copy() + for fkey, (_, _, fopt) in cls._field_indices.items(): + if fkey in extra_kw: + extra_kw.pop(fkey) elif not fopt: - raise AttributeError(f"Field '{fkey}' is required in {cls}") + raise AttributeError(f'Field `{fkey}` is required in {cls.__name__}') else: pass - if len(extra_kwargs) > 0: - raise AttributeError(f'The field names [{"".join(extra_kwargs.keys())}] are not defined in {cls}') + if len(extra_kw) > 0: + raise AttributeError(f'Fields [{"".join(extra_kw.keys())}] unknown in 
`{cls.__name__}`') value = cls.B(backing, hook, **kwargs) return cls(backing=value.get_backing()) - def __init_subclass__(cls, *args, **kwargs): - super().__init_subclass__(*args, **kwargs) - cls._o = 0 - for _, (_, fopt) in cls.fields().items(): - if fopt: - cls._o += 1 - assert cls._o == 0 or issubclass(cls.B, StableContainer) - - def __class_getitem__(cls, b) -> Type["Profile"]: - if not issubclass(b, StableContainer) and not issubclass(b, Container): - raise Exception(f"invalid Profile base: {b}") + def __init_subclass__(cls, **kwargs): + if 'b' not in kwargs: + raise TypeError(f'Missing base type: `{cls.__name__}(Profile)`') + b = kwargs.pop('b') + if not issubclass(b, StableContainer): + raise TypeError(f'Invalid base type: `Profile[{b.__name__}]`') + cls.B = b + + def __class_getitem__(cls, b) -> Type['Profile']: + def has_compatible_merkleization(ftyp, ftyp_base) -> bool: + if ftyp == ftyp_base: + return True + if issubclass(ftyp, boolean): + return issubclass(ftyp_base, boolean) + if issubclass(ftyp, uint8): + return issubclass(ftyp_base, uint8) + if issubclass(ftyp, uint16): + return issubclass(ftyp_base, uint16) + if issubclass(ftyp, uint32): + return issubclass(ftyp_base, uint32) + if issubclass(ftyp, uint64): + return issubclass(ftyp_base, uint64) + if issubclass(ftyp, uint128): + return issubclass(ftyp_base, uint128) + if issubclass(ftyp, uint256): + return issubclass(ftyp_base, uint256) + if issubclass(ftyp, Bitlist): + return ( + issubclass(ftyp_base, Bitlist) + and ftyp.limit() == ftyp_base.limit() + ) + if issubclass(ftyp, Bitvector): + return ( + issubclass(ftyp_base, Bitvector) + and ftyp.vector_length() == ftyp_base.vector_length() + ) + if issubclass(ftyp, ByteList): + if issubclass(ftyp_base, ByteList): + return ftyp.limit() == ftyp_base.limit() + return ( + issubclass(ftyp_base, List) + and ftyp.limit() == ftyp_base.limit() + and issubclass(ftyp_base.element_cls(), uint8) + ) + if issubclass(ftyp, ByteVector): + if issubclass(ftyp_base, 
ByteVector): + return ftyp.vector_length() == ftyp_base.vector_length() + return ( + issubclass(ftyp_base, Vector) + and ftyp.vector_length() == ftyp_base.vector_length() + and issubclass(ftyp_base.element_cls(), uint8) + ) + if issubclass(ftyp, List): + if issubclass(ftyp_base, ByteList): + return ( + ftyp.limit() == ftyp_base.limit() + and issubclass(ftyp.element_cls(), uint8) + ) + return ( + issubclass(ftyp_base, List) + and ftyp.limit() == ftyp_base.limit() + and has_compatible_merkleization(ftyp.element_cls(), ftyp_base.element_cls()) + ) + if issubclass(ftyp, Vector): + if issubclass(ftyp_base, ByteVector): + return ( + ftyp.vector_length() == ftyp_base.vector_length() + and issubclass(ftyp.element_cls(), uint8) + ) + return ( + issubclass(ftyp_base, Vector) + and ftyp.vector_length() == ftyp_base.vector_length() + and has_compatible_merkleization(ftyp.element_cls(), ftyp_base.element_cls()) + ) + if issubclass(ftyp, Container): + if not issubclass(ftyp_base, Container): + return False + fields = ftyp.fields() + fields_base = ftyp_base.fields() + if len(fields) != len(fields_base): + return False + for (fkey, t), (fkey_b, t_b) in zip(fields.items(), fields_base.items()): + if fkey != fkey_b: + return False + if not has_compatible_merkleization(t, t_b): + return False + return True + if issubclass(ftyp, StableContainer): + if not issubclass(ftyp_base, StableContainer): + return False + if ftyp.N != ftyp_base.N: + return False + fields = ftyp.fields() + fields_base = ftyp_base.fields() + if len(fields) != len(fields_base): + return False + for (fkey, t), (fkey_b, t_b) in zip(fields.items(), fields_base.items()): + if fkey != fkey_b: + return False + if not has_compatible_merkleization(t, t_b): + return False + return True + if issubclass(ftyp, Profile): + if issubclass(ftyp_base, StableContainer): + return has_compatible_merkleization(ftyp.B, ftyp_base) + if not issubclass(ftyp_base, Profile): + return False + if not has_compatible_merkleization(ftyp.B, 
ftyp_base.B): + return False + fields = ftyp.fields() + fields_base = ftyp_base.fields() + if len(fields) != len(fields_base): + return False + for (fkey, (t, _)), (fkey_b, (t_b, _)) in zip(fields.items(), fields_base.items()): + if fkey != fkey_b: + return False + if not has_compatible_merkleization(t, t_b): + return False + return True + return False - class ProfileView(Profile): - B = b + class ProfileMeta(ViewMeta): + def __new__(cls, name, bases, dct): + return super().__new__(cls, name, bases, dct, b=b) + + class ProfileView(Profile, metaclass=ProfileMeta): + def __init_subclass__(cls, **kwargs): + if 'B' in cls.__dict__: + raise TypeError(f'Cannot override `B` inside `{cls.__name__}`') + cls._field_indices = {} + cls._o = 0 + last_findex = -1 + for (fkey, t) in cls.__annotations__.items(): + if fkey not in cls.B._field_indices: + raise TypeError( + f'`{cls.__name__}` fields must exist in the base type ' + f'but `{fkey}` is not defined in `{cls.B.__name__}`' + ) + (findex, ftyp) = cls.B._field_indices[fkey] + if findex <= last_findex: + raise TypeError( + f'`{cls.__name__}` fields must have the same order as in the base type ' + f'but `{fkey}` is defined earlier than in `{cls.B.__name__}`' + ) + last_findex = findex + fopt = ( + get_origin(t) == PyUnion + and len(get_args(t)) == 2 + and type(None) in get_args(t) + ) + if fopt: + t = get_args(t)[0] if get_args(t)[0] is not type(None) else get_args(t)[1] + if not has_compatible_merkleization(t, ftyp): + raise TypeError( + f'`{cls.__name__}.{fkey}` has type `{t.__name__}`, incompatible ' + f'with base field `{cls.B.__name__}.{fkey}` of type `{ftyp.__name__}`' + ) + cls._field_indices[fkey] = (findex, t, fopt) + if fopt: + cls._o += 1 ProfileView.__name__ = ProfileView.type_repr() return ProfileView @classmethod def fields(cls) -> Dict[str, Tuple[Type[View], bool]]: - return all_fields(cls) + return { fkey: (ftyp, fopt) for fkey, (_, ftyp, fopt) in cls._field_indices.items() } @classmethod def 
is_fixed_byte_length(cls) -> bool: if cls._o > 0: return False - for _, (ftyp, _) in cls.fields().items(): + for (_, ftyp, _) in cls._field_indices.values(): if not ftyp.is_fixed_byte_length(): return False return True @@ -340,12 +511,12 @@ def type_byte_length(cls) -> int: if cls.is_fixed_byte_length(): return cls.min_byte_length() else: - raise Exception("dynamic length Profile does not have a fixed byte length") + raise Exception(f'Dynamic length `Profile` does not have a fixed byte length') @classmethod def min_byte_length(cls) -> int: total = Bitvector[cls._o].type_byte_length() if cls._o > 0 else 0 - for _, (ftyp, fopt) in cls.fields().items(): + for (_, ftyp, fopt) in cls._field_indices.values(): if fopt: continue if not ftyp.is_fixed_byte_length(): @@ -356,33 +527,43 @@ def min_byte_length(cls) -> int: @classmethod def max_byte_length(cls) -> int: total = Bitvector[cls._o].type_byte_length() if cls._o > 0 else 0 - for _, (ftyp, _) in cls.fields().items(): + for (_, ftyp, _) in cls._field_indices.values(): if not ftyp.is_fixed_byte_length(): total += OFFSET_BYTE_LENGTH total += ftyp.max_byte_length() return total + @classmethod + def is_packed(cls) -> bool: + return False + + @classmethod + def tree_depth(cls) -> int: + return cls.B.tree_depth() + + @classmethod + def item_elem_cls(cls, i: int) -> Type[View]: + return cls.B.item_elem_cls(i) + def active_fields(self) -> Bitvector: - assert issubclass(self.__class__.B, StableContainer) active_fields_node = super().get_backing().get_right() return Bitvector[self.__class__.B.N].view_from_backing(active_fields_node) def optional_fields(self) -> Bitvector: - assert issubclass(self.__class__.B, StableContainer) - assert self.__class__._o > 0 + if self.__class__._o == 0: + raise Exception(f'`{self.__class__.__name__}` does not have any `Optional[T]` fields') active_fields = self.active_fields() optional_fields = Bitvector[self.__class__._o]() oindex = 0 - for fkey, (_, fopt) in self.__class__.fields().items(): + for 
(findex, _, fopt) in self.__class__._field_indices.values(): if fopt: - (findex, _, _) = self.__class__.B._field_indices[fkey] optional_fields.set(oindex, active_fields.get(findex)) oindex += 1 return optional_fields def __getattribute__(self, item): if item == 'B': - raise AttributeError(f"use .__class__.{item} to access {item}") + raise AttributeError(f'Use `.__class__.{item}` to access `{item}`') return object.__getattribute__(self, item) def __getattr__(self, item): @@ -390,75 +571,42 @@ def __getattr__(self, item): return super().__getattribute__(item) else: try: - (ftyp, fopt) = self.__class__.fields()[item] - except KeyError: - raise AttributeError(f"unknown attribute {item}") - try: - (findex, _, _) = self.__class__.B._field_indices[item] + (findex, ftyp, fopt) = self.__class__._field_indices[item] except KeyError: - raise AttributeError(f"unknown attribute {item} in base") + raise AttributeError(f'Unknown field `{item}`') - if not issubclass(self.__class__.B, StableContainer): - return super().get(findex) - - if not self.active_fields().get(findex): - assert fopt - return None - - data = super().get_backing().get_left() - fnode = data.getter(2**get_depth(self.__class__.B.N) + findex) - return ftyp.view_from_backing(fnode) + value = stable_get(self, findex, ftyp, self.__class__.B.N) + assert value is not None or fopt + return value def __setattr__(self, key, value): if key[0] == '_': super().__setattr__(key, value) else: try: - (ftyp, fopt) = self.__class__.fields()[key] - except KeyError: - raise AttributeError(f"unknown attribute {key}") - try: - (findex, _, _) = self.__class__.B._field_indices[key] + (findex, ftyp, fopt) = self.__class__._field_indices[key] except KeyError: - raise AttributeError(f"unknown attribute {key} in base") - - if not issubclass(self.__class__.B, StableContainer): - super().set(findex, value) - return - - next_backing = self.get_backing() - - assert value is not None or fopt - active_fields = self.active_fields() - 
active_fields.set(findex, value is not None) - next_backing = next_backing.rebind_right(active_fields.get_backing()) - - if value is not None: - if isinstance(value, ftyp): - fnode = value.get_backing() - else: - fnode = ftyp.coerce_view(value).get_backing() - else: - fnode = zero_node(0) - data = next_backing.get_left() - next_data = data.setter(2**get_depth(self.__class__.B.N) + findex)(fnode) - next_backing = next_backing.rebind_left(next_data) + raise AttributeError(f'Unknown field `{key}`') - self.set_backing(next_backing) + if value is None and not fopt: + raise ValueError(f'Field `{key}` is required and cannot be set to `None`') + stable_set(self, findex, ftyp, self.__class__.B.N, value) def __repr__(self): - return repr(self) + return f'{self.__class__.type_repr()}:\n' + '\n'.join( + indent(field_val_repr(self, fkey, ftyp, fopt), ' ') + for fkey, (_, ftyp, fopt) in self.__class__._field_indices.items()) @classmethod def type_repr(cls) -> str: - return f"Profile[{cls.B.__name__}]" + return f'Profile[{cls.B.__name__}]' @classmethod - def deserialize(cls: Type[B], stream: BinaryIO, scope: int) -> B: + def deserialize(cls: Type[BV], stream: BinaryIO, scope: int) -> BV: if cls._o > 0: num_prefix_bytes = Bitvector[cls._o].type_byte_length() if scope < num_prefix_bytes: - raise ValueError("scope too small, cannot read Profile optional fields") + raise ValueError(f'Scope too small for `Profile[{cls.B.__name__}]` optional fields') optional_fields = Bitvector[cls._o].deserialize(stream, num_prefix_bytes) scope = scope - num_prefix_bytes @@ -466,11 +614,11 @@ def deserialize(cls: Type[B], stream: BinaryIO, scope: int) -> B: dyn_fields: PyList[FieldOffset] = [] fixed_size = 0 oindex = 0 - for fkey, (ftyp, fopt) in cls.fields().items(): + for fkey, (_, ftyp, fopt) in cls._field_indices.items(): if fopt: - have_field = optional_fields.get(oindex) + has_field = optional_fields.get(oindex) oindex += 1 - if not have_field: + if not has_field: field_values[fkey] = None 
continue if ftyp.is_fixed_byte_length(): @@ -484,21 +632,20 @@ def deserialize(cls: Type[B], stream: BinaryIO, scope: int) -> B: assert oindex == cls._o if len(dyn_fields) > 0: if dyn_fields[0].offset < fixed_size: - raise Exception(f"first offset {dyn_fields[0].offset} is " - f"smaller than expected fixed size {fixed_size}") + raise Exception(f'First offset {dyn_fields[0].offset} is ' + f'smaller than expected fixed size {fixed_size}') for i, (fkey, ftyp, foffset) in enumerate(dyn_fields): next_offset = dyn_fields[i + 1].offset if i + 1 < len(dyn_fields) else scope if foffset > next_offset: - raise Exception(f"offset {i} is invalid: {foffset} " - f"larger than next offset {next_offset}") + raise Exception(f'Offset {i} is invalid: {foffset} ' + f'larger than next offset {next_offset}') fsize = next_offset - foffset f_min_size, f_max_size = ftyp.min_byte_length(), ftyp.max_byte_length() if not (f_min_size <= fsize <= f_max_size): - raise Exception(f"offset {i} is invalid, size out of bounds: " - f"{foffset}, next {next_offset}, implied size: {fsize}, " - f"size bounds: [{f_min_size}, {f_max_size}]") + raise Exception(f'Offset {i} is invalid, size out of bounds: ' + f'{foffset}, next {next_offset}, implied size: {fsize}, ' + f'size bounds: [{f_min_size}, {f_max_size}]') field_values[fkey] = ftyp.deserialize(stream, fsize) - return cls(**field_values) # type: ignore def serialize(self, stream: BinaryIO) -> int: @@ -509,44 +656,39 @@ def serialize(self, stream: BinaryIO) -> int: num_prefix_bytes = 0 num_data_bytes = 0 + has_dyn_fields = False oindex = 0 - for _, (ftyp, fopt) in self.__class__.fields().items(): + for (_, ftyp, fopt) in self.__class__._field_indices.values(): if fopt: - have_field = optional_fields.get(oindex) + has_field = optional_fields.get(oindex) oindex += 1 - if not have_field: + if not has_field: continue if ftyp.is_fixed_byte_length(): num_data_bytes += ftyp.type_byte_length() else: num_data_bytes += OFFSET_BYTE_LENGTH + has_dyn_fields = True 
assert oindex == self.__class__._o - temp_dyn_stream = io.BytesIO() - if issubclass(self.__class__.B, StableContainer): - data = super().get_backing().get_left() - active_fields = self.active_fields() - n = self.__class__.B.N - else: - data = super().get_backing() - n = len(self.__class__.B.fields()) - for fkey, (ftyp, _) in self.__class__.fields().items(): - if issubclass(self.__class__.B, StableContainer): - (findex, _, _) = self.__class__.B._field_indices[fkey] - if not active_fields.get(findex): - continue - fnode = data.getter(2**get_depth(n) + findex) - else: - findex = self.__class__.B._field_indices[fkey] - fnode = data.getter(2**get_depth(n) + findex) + if has_dyn_fields: + temp_dyn_stream = io.BytesIO() + data = super().get_backing().get_left() + active_fields = self.active_fields() + n = self.__class__.B.N + for (findex, ftyp, _) in self.__class__._field_indices.values(): + if not active_fields.get(findex): + continue + fnode = data.getter(2**get_depth(n) + findex) v = ftyp.view_from_backing(fnode) if ftyp.is_fixed_byte_length(): v.serialize(stream) else: encode_offset(stream, num_data_bytes) num_data_bytes += v.serialize(temp_dyn_stream) # type: ignore - temp_dyn_stream.seek(0) - stream.write(temp_dyn_stream.read(num_data_bytes)) + if has_dyn_fields: + temp_dyn_stream.seek(0) + stream.write(temp_dyn_stream.read(num_data_bytes)) return num_prefix_bytes + num_data_bytes @@ -554,58 +696,12 @@ def serialize(self, stream: BinaryIO) -> int: def navigate_type(cls, key: Any) -> Type[View]: if key == '__active_fields__': return Bitvector[cls.B.N] - (ftyp, fopt) = cls.fields()[key] - if fopt: - return Optional[ftyp] - return ftyp + (_, ftyp, fopt) = cls._field_indices[key] + return Optional[ftyp] if fopt else ftyp @classmethod def key_to_static_gindex(cls, key: Any) -> Gindex: if key == '__active_fields__': return RIGHT_GINDEX - (_, _) = cls.fields()[key] - if issubclass(cls.B, StableContainer): - (findex, _, _) = cls.B._field_indices[key] - return 
2**get_depth(cls.B.N) * 2 + findex - else: - findex = cls.B._field_indices[key] - n = len(cls.B.fields()) - return 2**get_depth(n) + findex - - -class OneOf(ComplexView): - def __class_getitem__(cls, b) -> Type["OneOf"]: - if not issubclass(b, StableContainer) and not issubclass(b, Container): - raise Exception(f"invalid OneOf base: {b}") - - class OneOfView(OneOf, b): - B = b - - @classmethod - def fields(cls): - return b.fields() - - OneOfView.__name__ = OneOfView.type_repr() - return OneOfView - - def __repr__(self): - return repr(self) - - @classmethod - def type_repr(cls) -> str: - return f"OneOf[{cls.B}]" - - @classmethod - def decode_bytes(cls: Type[B], bytez: bytes, *args, **kwargs) -> B: - stream = io.BytesIO() - stream.write(bytez) - stream.seek(0) - return cls.deserialize(stream, len(bytez), *args, **kwargs) - - @classmethod - def deserialize(cls: Type[B], stream: BinaryIO, scope: int, *args, **kwargs) -> B: - value = cls.B.deserialize(stream, scope) - v = cls.select_from_base(value, *args, **kwargs) - if not issubclass(v.B, cls.B): - raise Exception(f"unsupported select_from_base result: {v}") - return v(backing=value.get_backing()) + (findex, _, _) = cls._field_indices[key] + return 2**get_depth(cls.B.N) * 2 + findex diff --git a/assets/eip-7495/tests.py b/assets/eip-7495/tests.py index 89bffdb903e38..086ef4712b4f4 100644 --- a/assets/eip-7495/tests.py +++ b/assets/eip-7495/tests.py @@ -1,47 +1,31 @@ -from typing import Optional, Type +from typing import Optional from remerkleable.basic import uint8, uint16, uint32, uint64 from remerkleable.bitfields import Bitvector from remerkleable.complex import Container, List -from stable_container import OneOf, Profile, StableContainer +from stable_container import Profile, StableContainer -# Defines the common merkleization format and a portable serialization format class Shape(StableContainer[4]): side: Optional[uint16] - color: uint8 + color: Optional[uint8] radius: Optional[uint16] -# Inherits merkleization 
format from `Shape`, but is serialized more compactly class Square(Profile[Shape]): side: uint16 color: uint8 -# Inherits merkleization format from `Shape`, but is serialized more compactly class Circle(Profile[Shape]): color: uint8 radius: uint16 -class AnyShape(OneOf[Shape]): - @classmethod - def select_from_base(cls, value: Shape, circle_allowed = False) -> Type[Shape]: - if value.radius is not None: - assert circle_allowed - return Circle - if value.side is not None: - return Square - assert False - -# Defines a container with immutable scheme that contains two `StableContainer` class ShapePair(Container): shape_1: Shape shape_2: Shape -# Inherits merkleization format from `ShapePair`, and serializes more compactly -class SquarePair(Profile[ShapePair]): +class SquarePair(Container): shape_1: Square shape_2: Square -# Inherits merkleization format from `ShapePair`, and serializes more compactly -class CirclePair(Profile[ShapePair]): +class CirclePair(Container): shape_1: Circle shape_2: Circle @@ -50,6 +34,7 @@ class ShapePayload(Container): side: uint16 color: uint8 radius: uint16 + class ShapeRepr(Container): value: ShapePayload active_fields: Bitvector[4] @@ -58,18 +43,6 @@ class ShapePairRepr(Container): shape_1: ShapeRepr shape_2: ShapeRepr -class AnyShapePair(OneOf[ShapePair]): - @classmethod - def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[ShapePair]: - typ_1 = AnyShape.select_from_base(value.shape_1, circle_allowed) - typ_2 = AnyShape.select_from_base(value.shape_2, circle_allowed) - assert typ_1 == typ_2 - if typ_1 is Circle: - return CirclePair - if typ_1 is Square: - return SquarePair - assert False - # Square tests square_bytes_stable = bytes.fromhex("03420001") square_bytes_profile = bytes.fromhex("420001") @@ -89,9 +62,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert all(square.encode_bytes() == square_bytes_profile for square in squares) assert ( 
Square(backing=Shape.decode_bytes(square_bytes_stable).get_backing()) == - Square.decode_bytes(square_bytes_profile) == - AnyShape.decode_bytes(square_bytes_stable) == - AnyShape.decode_bytes(square_bytes_stable, circle_allowed = True) + Square.decode_bytes(square_bytes_profile) ) assert all(shape.hash_tree_root() == square_root for shape in shapes) assert all(square.hash_tree_root() == square_root for square in squares) @@ -128,9 +99,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert all(square.encode_bytes() == square_bytes_profile for square in squares) assert ( Square(backing=Shape.decode_bytes(square_bytes_stable).get_backing()) == - Square.decode_bytes(square_bytes_profile) == - AnyShape.decode_bytes(square_bytes_stable) == - AnyShape.decode_bytes(square_bytes_stable, circle_allowed = True) + Square.decode_bytes(square_bytes_profile) ) assert all(shape.hash_tree_root() == square_root for shape in shapes) assert all(square.hash_tree_root() == square_root for square in squares) @@ -169,8 +138,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert all(circle.encode_bytes() == circle_bytes_profile for circle in circles) assert ( Circle(backing=Shape.decode_bytes(circle_bytes_stable).get_backing()) == - Circle.decode_bytes(circle_bytes_profile) == - AnyShape.decode_bytes(circle_bytes_stable, circle_allowed = True) + Circle.decode_bytes(circle_bytes_profile) ) assert all(shape.hash_tree_root() == circle_root for shape in shapes) assert all(circle.hash_tree_root() == circle_root for circle in circles) @@ -191,11 +159,6 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert False except: pass -try: - circle = AnyShape.decode_bytes(circle_bytes_stable, circle_allowed = False) - assert False -except: - pass # SquarePair tests square_pair_bytes_stable = bytes.fromhex("080000000c0000000342000103690001") @@ -228,9 +191,7 @@ def select_from_base(cls, value: 
ShapePair, circle_allowed = False) -> Type[Shap assert all(pair.encode_bytes() == square_pair_bytes_profile for pair in square_pairs) assert ( SquarePair(backing=ShapePair.decode_bytes(square_pair_bytes_stable).get_backing()) == - SquarePair.decode_bytes(square_pair_bytes_profile) == - AnyShapePair.decode_bytes(square_pair_bytes_stable) == - AnyShapePair.decode_bytes(square_pair_bytes_stable, circle_allowed = True) + SquarePair.decode_bytes(square_pair_bytes_profile) ) assert all(pair.hash_tree_root() == square_pair_root for pair in shape_pairs) assert all(pair.hash_tree_root() == square_pair_root for pair in square_pairs) @@ -266,8 +227,7 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert all(pair.encode_bytes() == circle_pair_bytes_profile for pair in circle_pairs) assert ( CirclePair(backing=ShapePair.decode_bytes(circle_pair_bytes_stable).get_backing()) == - CirclePair.decode_bytes(circle_pair_bytes_profile) == - AnyShapePair.decode_bytes(circle_pair_bytes_stable, circle_allowed = True) + CirclePair.decode_bytes(circle_pair_bytes_profile) ) assert all(pair.hash_tree_root() == circle_pair_root for pair in shape_pairs) assert all(pair.hash_tree_root() == circle_pair_root for pair in circle_pairs) @@ -287,11 +247,6 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert False except: pass -try: - shape = AnyShape.decode_bytes(shape_bytes) - assert False -except: - pass shape = Shape(side=0x42, color=1, radius=0x42) shape_bytes = bytes.fromhex("074200014200") assert shape.encode_bytes() == shape_bytes @@ -306,16 +261,6 @@ def select_from_base(cls, value: ShapePair, circle_allowed = False) -> Type[Shap assert False except: pass -try: - shape = AnyShape.decode_bytes(shape_bytes) - assert False -except: - pass -try: - shape = AnyShape.decode_bytes("00") - assert False -except: - pass try: shape = Shape.decode_bytes("00") assert False @@ -369,13 +314,13 @@ class ShapeContainerRepr(Container): # 
basic container class Shape1(StableContainer[4]): side: Optional[uint16] - color: uint8 + color: Optional[uint8] radius: Optional[uint16] # basic container with different depth class Shape2(StableContainer[8]): side: Optional[uint16] - color: uint8 + color: Optional[uint8] radius: Optional[uint16] # basic container with variable fields